modeling_bert.py 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394
  1. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch BERT model."""
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  24. from ...generation import GenerationMixin
  25. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import (
  28. BaseModelOutputWithPastAndCrossAttentions,
  29. BaseModelOutputWithPoolingAndCrossAttentions,
  30. CausalLMOutputWithCrossAttentions,
  31. MaskedLMOutput,
  32. MultipleChoiceModelOutput,
  33. NextSentencePredictorOutput,
  34. QuestionAnsweringModelOutput,
  35. SequenceClassifierOutput,
  36. TokenClassifierOutput,
  37. )
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...pytorch_utils import apply_chunking_to_forward
  41. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
  42. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  43. from ...utils.output_capturing import capture_outputs
  44. from .configuration_bert import BertConfig
# Module-level logger obtained from the library's `logging` utilities, named after this module.
logger = logging.get_logger(__name__)
  46. class BertEmbeddings(nn.Module):
  47. """Construct the embeddings from word, position and token_type embeddings."""
  48. def __init__(self, config):
  49. super().__init__()
  50. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  51. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  52. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  53. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  54. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  55. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  56. self.register_buffer(
  57. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  58. )
  59. self.register_buffer(
  60. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  61. )
  62. def forward(
  63. self,
  64. input_ids: torch.LongTensor | None = None,
  65. token_type_ids: torch.LongTensor | None = None,
  66. position_ids: torch.LongTensor | None = None,
  67. inputs_embeds: torch.FloatTensor | None = None,
  68. past_key_values_length: int = 0,
  69. ) -> torch.Tensor:
  70. if input_ids is not None:
  71. input_shape = input_ids.size()
  72. else:
  73. input_shape = inputs_embeds.size()[:-1]
  74. batch_size, seq_length = input_shape
  75. if position_ids is None:
  76. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
  77. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  78. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  79. # issue #5664
  80. if token_type_ids is None:
  81. if hasattr(self, "token_type_ids"):
  82. # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
  83. buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
  84. buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
  85. token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
  86. else:
  87. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  88. if inputs_embeds is None:
  89. inputs_embeds = self.word_embeddings(input_ids)
  90. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  91. embeddings = inputs_embeds + token_type_embeddings
  92. position_embeddings = self.position_embeddings(position_ids)
  93. embeddings = embeddings + position_embeddings
  94. embeddings = self.LayerNorm(embeddings)
  95. embeddings = self.dropout(embeddings)
  96. return embeddings
  97. def eager_attention_forward(
  98. module: nn.Module,
  99. query: torch.Tensor,
  100. key: torch.Tensor,
  101. value: torch.Tensor,
  102. attention_mask: torch.Tensor | None,
  103. scaling: float | None = None,
  104. dropout: float = 0.0,
  105. **kwargs: Unpack[TransformersKwargs],
  106. ):
  107. if scaling is None:
  108. scaling = query.size(-1) ** -0.5
  109. # Take the dot product between "query" and "key" to get the raw attention scores.
  110. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  111. if attention_mask is not None:
  112. attn_weights = attn_weights + attention_mask
  113. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  114. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  115. attn_output = torch.matmul(attn_weights, value)
  116. attn_output = attn_output.transpose(1, 2).contiguous()
  117. return attn_output, attn_weights
  118. class BertSelfAttention(nn.Module):
  119. def __init__(self, config, is_causal=False, layer_idx=None):
  120. super().__init__()
  121. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  122. raise ValueError(
  123. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  124. f"heads ({config.num_attention_heads})"
  125. )
  126. self.config = config
  127. self.num_attention_heads = config.num_attention_heads
  128. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  129. self.all_head_size = self.num_attention_heads * self.attention_head_size
  130. self.scaling = self.attention_head_size**-0.5
  131. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  132. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  133. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  134. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  135. self.is_decoder = config.is_decoder
  136. self.is_causal = is_causal
  137. self.layer_idx = layer_idx
  138. def forward(
  139. self,
  140. hidden_states: torch.Tensor,
  141. attention_mask: torch.FloatTensor | None = None,
  142. past_key_values: Cache | None = None,
  143. **kwargs: Unpack[TransformersKwargs],
  144. ) -> tuple[torch.Tensor]:
  145. input_shape = hidden_states.shape[:-1]
  146. hidden_shape = (*input_shape, -1, self.attention_head_size)
  147. # get all proj
  148. query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
  149. key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
  150. value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
  151. if past_key_values is not None:
  152. # decoder-only bert can have a simple dynamic cache for example
  153. current_past_key_values = past_key_values
  154. if isinstance(past_key_values, EncoderDecoderCache):
  155. current_past_key_values = past_key_values.self_attention_cache
  156. # save all key/value_layer to cache to be re-used for fast auto-regressive generation
  157. key_layer, value_layer = current_past_key_values.update(key_layer, value_layer, self.layer_idx)
  158. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  159. self.config._attn_implementation, eager_attention_forward
  160. )
  161. attn_output, attn_weights = attention_interface(
  162. self,
  163. query_layer,
  164. key_layer,
  165. value_layer,
  166. attention_mask,
  167. dropout=0.0 if not self.training else self.dropout.p,
  168. scaling=self.scaling,
  169. **kwargs,
  170. )
  171. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  172. return attn_output, attn_weights
  173. class BertCrossAttention(nn.Module):
  174. def __init__(self, config, is_causal=False, layer_idx=None):
  175. super().__init__()
  176. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  177. raise ValueError(
  178. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  179. f"heads ({config.num_attention_heads})"
  180. )
  181. self.config = config
  182. self.num_attention_heads = config.num_attention_heads
  183. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  184. self.all_head_size = self.num_attention_heads * self.attention_head_size
  185. self.scaling = self.attention_head_size**-0.5
  186. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  187. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  188. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  189. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  190. self.is_causal = is_causal
  191. self.layer_idx = layer_idx
  192. def forward(
  193. self,
  194. hidden_states: torch.Tensor,
  195. encoder_hidden_states: torch.FloatTensor | None = None,
  196. attention_mask: torch.FloatTensor | None = None,
  197. past_key_values: EncoderDecoderCache | None = None,
  198. **kwargs: Unpack[TransformersKwargs],
  199. ) -> tuple[torch.Tensor]:
  200. # determine input shapes
  201. input_shape = hidden_states.shape[:-1]
  202. hidden_shape = (*input_shape, -1, self.attention_head_size)
  203. # get query proj
  204. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  205. is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
  206. if past_key_values is not None and is_updated:
  207. # reuse k,v, cross_attentions
  208. key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
  209. value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
  210. else:
  211. kv_shape = (*encoder_hidden_states.shape[:-1], -1, self.attention_head_size)
  212. key_layer = self.key(encoder_hidden_states).view(kv_shape).transpose(1, 2)
  213. value_layer = self.value(encoder_hidden_states).view(kv_shape).transpose(1, 2)
  214. if past_key_values is not None:
  215. # save all states to the cache
  216. key_layer, value_layer = past_key_values.cross_attention_cache.update(
  217. key_layer, value_layer, self.layer_idx
  218. )
  219. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  220. past_key_values.is_updated[self.layer_idx] = True
  221. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  222. self.config._attn_implementation, eager_attention_forward
  223. )
  224. attn_output, attn_weights = attention_interface(
  225. self,
  226. query_layer,
  227. key_layer,
  228. value_layer,
  229. attention_mask,
  230. dropout=0.0 if not self.training else self.dropout.p,
  231. scaling=self.scaling,
  232. **kwargs,
  233. )
  234. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  235. return attn_output, attn_weights
  236. class BertSelfOutput(nn.Module):
  237. def __init__(self, config):
  238. super().__init__()
  239. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  240. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  241. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  242. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  243. hidden_states = self.dense(hidden_states)
  244. hidden_states = self.dropout(hidden_states)
  245. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  246. return hidden_states
  247. class BertAttention(nn.Module):
  248. def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
  249. super().__init__()
  250. self.is_cross_attention = is_cross_attention
  251. attention_class = BertCrossAttention if is_cross_attention else BertSelfAttention
  252. self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
  253. self.output = BertSelfOutput(config)
  254. def forward(
  255. self,
  256. hidden_states: torch.Tensor,
  257. attention_mask: torch.FloatTensor | None = None,
  258. encoder_hidden_states: torch.FloatTensor | None = None,
  259. encoder_attention_mask: torch.FloatTensor | None = None,
  260. past_key_values: Cache | None = None,
  261. **kwargs: Unpack[TransformersKwargs],
  262. ) -> tuple[torch.Tensor]:
  263. attention_mask = attention_mask if not self.is_cross_attention else encoder_attention_mask
  264. attention_output, attn_weights = self.self(
  265. hidden_states,
  266. encoder_hidden_states=encoder_hidden_states,
  267. attention_mask=attention_mask,
  268. past_key_values=past_key_values,
  269. **kwargs,
  270. )
  271. attention_output = self.output(attention_output, hidden_states)
  272. return attention_output, attn_weights
  273. class BertIntermediate(nn.Module):
  274. def __init__(self, config):
  275. super().__init__()
  276. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  277. if isinstance(config.hidden_act, str):
  278. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  279. else:
  280. self.intermediate_act_fn = config.hidden_act
  281. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  282. hidden_states = self.dense(hidden_states)
  283. hidden_states = self.intermediate_act_fn(hidden_states)
  284. return hidden_states
  285. class BertOutput(nn.Module):
  286. def __init__(self, config):
  287. super().__init__()
  288. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  289. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  290. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  291. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  292. hidden_states = self.dense(hidden_states)
  293. hidden_states = self.dropout(hidden_states)
  294. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  295. return hidden_states
  296. class BertLayer(GradientCheckpointingLayer):
  297. def __init__(self, config, layer_idx=None):
  298. super().__init__()
  299. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  300. self.seq_len_dim = 1
  301. self.attention = BertAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
  302. self.is_decoder = config.is_decoder
  303. self.add_cross_attention = config.add_cross_attention
  304. if self.add_cross_attention:
  305. if not self.is_decoder:
  306. raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
  307. self.crossattention = BertAttention(
  308. config,
  309. is_causal=False,
  310. layer_idx=layer_idx,
  311. is_cross_attention=True,
  312. )
  313. self.intermediate = BertIntermediate(config)
  314. self.output = BertOutput(config)
  315. def forward(
  316. self,
  317. hidden_states: torch.Tensor,
  318. attention_mask: torch.FloatTensor | None = None,
  319. encoder_hidden_states: torch.FloatTensor | None = None,
  320. encoder_attention_mask: torch.FloatTensor | None = None,
  321. past_key_values: Cache | None = None,
  322. **kwargs: Unpack[TransformersKwargs],
  323. ) -> torch.Tensor:
  324. self_attention_output, _ = self.attention(
  325. hidden_states,
  326. attention_mask,
  327. past_key_values=past_key_values,
  328. **kwargs,
  329. )
  330. attention_output = self_attention_output
  331. if self.is_decoder and encoder_hidden_states is not None:
  332. if not hasattr(self, "crossattention"):
  333. raise ValueError(
  334. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  335. " by setting `config.add_cross_attention=True`"
  336. )
  337. cross_attention_output, _ = self.crossattention(
  338. self_attention_output,
  339. None, # attention_mask
  340. encoder_hidden_states,
  341. encoder_attention_mask,
  342. past_key_values=past_key_values,
  343. **kwargs,
  344. )
  345. attention_output = cross_attention_output
  346. layer_output = apply_chunking_to_forward(
  347. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  348. )
  349. return layer_output
  350. def feed_forward_chunk(self, attention_output):
  351. intermediate_output = self.intermediate(attention_output)
  352. layer_output = self.output(intermediate_output, attention_output)
  353. return layer_output
  354. class BertEncoder(nn.Module):
  355. def __init__(self, config):
  356. super().__init__()
  357. self.config = config
  358. self.layer = nn.ModuleList([BertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  359. def forward(
  360. self,
  361. hidden_states: torch.Tensor,
  362. attention_mask: torch.FloatTensor | None = None,
  363. encoder_hidden_states: torch.FloatTensor | None = None,
  364. encoder_attention_mask: torch.FloatTensor | None = None,
  365. past_key_values: Cache | None = None,
  366. use_cache: bool | None = None,
  367. **kwargs: Unpack[TransformersKwargs],
  368. ) -> tuple[torch.Tensor] | BaseModelOutputWithPastAndCrossAttentions:
  369. for i, layer_module in enumerate(self.layer):
  370. hidden_states = layer_module(
  371. hidden_states,
  372. attention_mask,
  373. encoder_hidden_states, # as a positional argument for gradient checkpointing
  374. encoder_attention_mask=encoder_attention_mask,
  375. past_key_values=past_key_values,
  376. **kwargs,
  377. )
  378. return BaseModelOutputWithPastAndCrossAttentions(
  379. last_hidden_state=hidden_states,
  380. past_key_values=past_key_values if use_cache else None,
  381. )
  382. class BertPooler(nn.Module):
  383. def __init__(self, config):
  384. super().__init__()
  385. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  386. self.activation = nn.Tanh()
  387. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  388. # We "pool" the model by simply taking the hidden state corresponding
  389. # to the first token.
  390. first_token_tensor = hidden_states[:, 0]
  391. pooled_output = self.dense(first_token_tensor)
  392. pooled_output = self.activation(pooled_output)
  393. return pooled_output
  394. class BertPredictionHeadTransform(nn.Module):
  395. def __init__(self, config):
  396. super().__init__()
  397. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  398. if isinstance(config.hidden_act, str):
  399. self.transform_act_fn = ACT2FN[config.hidden_act]
  400. else:
  401. self.transform_act_fn = config.hidden_act
  402. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  403. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  404. hidden_states = self.dense(hidden_states)
  405. hidden_states = self.transform_act_fn(hidden_states)
  406. hidden_states = self.LayerNorm(hidden_states)
  407. return hidden_states
  408. class BertLMPredictionHead(nn.Module):
  409. def __init__(self, config):
  410. super().__init__()
  411. self.transform = BertPredictionHeadTransform(config)
  412. # The output weights are the same as the input embeddings, but there is
  413. # an output-only bias for each token.
  414. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
  415. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  416. def forward(self, hidden_states):
  417. hidden_states = self.transform(hidden_states)
  418. hidden_states = self.decoder(hidden_states)
  419. return hidden_states
  420. class BertOnlyMLMHead(nn.Module):
  421. def __init__(self, config):
  422. super().__init__()
  423. self.predictions = BertLMPredictionHead(config)
  424. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  425. prediction_scores = self.predictions(sequence_output)
  426. return prediction_scores
  427. class BertOnlyNSPHead(nn.Module):
  428. def __init__(self, config):
  429. super().__init__()
  430. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  431. def forward(self, pooled_output):
  432. seq_relationship_score = self.seq_relationship(pooled_output)
  433. return seq_relationship_score
  434. class BertPreTrainingHeads(nn.Module):
  435. def __init__(self, config):
  436. super().__init__()
  437. self.predictions = BertLMPredictionHead(config)
  438. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  439. def forward(self, sequence_output, pooled_output):
  440. prediction_scores = self.predictions(sequence_output)
  441. seq_relationship_score = self.seq_relationship(pooled_output)
  442. return prediction_scores, seq_relationship_score
  443. @auto_docstring
  444. class BertPreTrainedModel(PreTrainedModel):
  445. config_class = BertConfig
  446. base_model_prefix = "bert"
  447. supports_gradient_checkpointing = True
  448. _supports_flash_attn = True
  449. _supports_sdpa = True
  450. _supports_flex_attn = True
  451. _supports_attention_backend = True
  452. _can_record_outputs = {
  453. "hidden_states": BertLayer,
  454. "attentions": BertSelfAttention,
  455. "cross_attentions": BertCrossAttention,
  456. }
  457. @torch.no_grad()
  458. def _init_weights(self, module):
  459. """Initialize the weights"""
  460. super()._init_weights(module)
  461. if isinstance(module, BertLMPredictionHead):
  462. init.zeros_(module.bias)
  463. elif isinstance(module, BertEmbeddings):
  464. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  465. init.zeros_(module.token_type_ids)
# Output container returned by `BertForPreTraining`.
# NOTE(review): ModelOutput presumably derives tuple/index access from field order, so keep the order below — confirm.
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`BertForPreTraining`].
    """
)
class BertForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
        before SoftMax).
    """

    # Combined masked-LM + next-sentence loss; stays None unless labels were provided.
    loss: torch.FloatTensor | None = None
    # Per-token vocabulary logits from the masked-LM head.
    prediction_logits: torch.FloatTensor | None = None
    # Two-way next-sentence logits.
    seq_relationship_logits: torch.FloatTensor | None = None
    # Optional tuple of hidden states (presumably one per layer; documented by `auto_docstring` — confirm).
    hidden_states: tuple[torch.FloatTensor] | None = None
    # Optional tuple of attention weights (presumably one per layer — confirm).
    attentions: tuple[torch.FloatTensor] | None = None
  488. @auto_docstring(
  489. custom_intro="""
  490. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  491. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  492. all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  493. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  494. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  495. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  496. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  497. """
  498. )
  499. class BertModel(BertPreTrainedModel):
  500. _no_split_modules = ["BertEmbeddings", "BertLayer"]
  501. def __init__(self, config, add_pooling_layer=True):
  502. r"""
  503. add_pooling_layer (bool, *optional*, defaults to `True`):
  504. Whether to add a pooling layer
  505. """
  506. super().__init__(config)
  507. self.config = config
  508. self.gradient_checkpointing = False
  509. self.embeddings = BertEmbeddings(config)
  510. self.encoder = BertEncoder(config)
  511. self.pooler = BertPooler(config) if add_pooling_layer else None
  512. # Initialize weights and apply final processing
  513. self.post_init()
  514. def get_input_embeddings(self):
  515. return self.embeddings.word_embeddings
  516. def set_input_embeddings(self, value):
  517. self.embeddings.word_embeddings = value
  518. @merge_with_config_defaults
  519. @capture_outputs
  520. @auto_docstring
  521. def forward(
  522. self,
  523. input_ids: torch.Tensor | None = None,
  524. attention_mask: torch.Tensor | None = None,
  525. token_type_ids: torch.Tensor | None = None,
  526. position_ids: torch.Tensor | None = None,
  527. inputs_embeds: torch.Tensor | None = None,
  528. encoder_hidden_states: torch.Tensor | None = None,
  529. encoder_attention_mask: torch.Tensor | None = None,
  530. past_key_values: Cache | None = None,
  531. use_cache: bool | None = None,
  532. **kwargs: Unpack[TransformersKwargs],
  533. ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
  534. if (input_ids is None) ^ (inputs_embeds is not None):
  535. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  536. if self.config.is_decoder:
  537. use_cache = use_cache if use_cache is not None else self.config.use_cache
  538. else:
  539. use_cache = False
  540. if use_cache and past_key_values is None:
  541. past_key_values = (
  542. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  543. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  544. else DynamicCache(config=self.config)
  545. )
  546. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  547. embedding_output = self.embeddings(
  548. input_ids=input_ids,
  549. position_ids=position_ids,
  550. token_type_ids=token_type_ids,
  551. inputs_embeds=inputs_embeds,
  552. past_key_values_length=past_key_values_length,
  553. )
  554. attention_mask, encoder_attention_mask = self._create_attention_masks(
  555. attention_mask=attention_mask,
  556. encoder_attention_mask=encoder_attention_mask,
  557. embedding_output=embedding_output,
  558. encoder_hidden_states=encoder_hidden_states,
  559. past_key_values=past_key_values,
  560. )
  561. encoder_outputs = self.encoder(
  562. embedding_output,
  563. attention_mask=attention_mask,
  564. encoder_hidden_states=encoder_hidden_states,
  565. encoder_attention_mask=encoder_attention_mask,
  566. past_key_values=past_key_values,
  567. use_cache=use_cache,
  568. position_ids=position_ids,
  569. **kwargs,
  570. )
  571. sequence_output = encoder_outputs.last_hidden_state
  572. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  573. return BaseModelOutputWithPoolingAndCrossAttentions(
  574. last_hidden_state=sequence_output,
  575. pooler_output=pooled_output,
  576. past_key_values=encoder_outputs.past_key_values,
  577. )
  578. def _create_attention_masks(
  579. self,
  580. attention_mask,
  581. encoder_attention_mask,
  582. embedding_output,
  583. encoder_hidden_states,
  584. past_key_values,
  585. ):
  586. if self.config.is_decoder:
  587. attention_mask = create_causal_mask(
  588. config=self.config,
  589. inputs_embeds=embedding_output,
  590. attention_mask=attention_mask,
  591. past_key_values=past_key_values,
  592. )
  593. else:
  594. attention_mask = create_bidirectional_mask(
  595. config=self.config,
  596. inputs_embeds=embedding_output,
  597. attention_mask=attention_mask,
  598. )
  599. if encoder_attention_mask is not None:
  600. encoder_attention_mask = create_bidirectional_mask(
  601. config=self.config,
  602. inputs_embeds=embedding_output,
  603. attention_mask=encoder_attention_mask,
  604. encoder_hidden_states=encoder_hidden_states,
  605. )
  606. return attention_mask, encoder_attention_mask
  607. @auto_docstring(
  608. custom_intro="""
  609. Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
  610. sentence prediction (classification)` head.
  611. """
  612. )
  613. class BertForPreTraining(BertPreTrainedModel):
  614. _tied_weights_keys = {
  615. "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
  616. "cls.predictions.decoder.bias": "cls.predictions.bias",
  617. }
  618. def __init__(self, config):
  619. super().__init__(config)
  620. self.bert = BertModel(config)
  621. self.cls = BertPreTrainingHeads(config)
  622. # Initialize weights and apply final processing
  623. self.post_init()
  624. def get_output_embeddings(self):
  625. return self.cls.predictions.decoder
  626. def set_output_embeddings(self, new_embeddings):
  627. self.cls.predictions.decoder = new_embeddings
  628. self.cls.predictions.bias = new_embeddings.bias
  629. @can_return_tuple
  630. @auto_docstring
  631. def forward(
  632. self,
  633. input_ids: torch.Tensor | None = None,
  634. attention_mask: torch.Tensor | None = None,
  635. token_type_ids: torch.Tensor | None = None,
  636. position_ids: torch.Tensor | None = None,
  637. inputs_embeds: torch.Tensor | None = None,
  638. labels: torch.Tensor | None = None,
  639. next_sentence_label: torch.Tensor | None = None,
  640. **kwargs: Unpack[TransformersKwargs],
  641. ) -> tuple[torch.Tensor] | BertForPreTrainingOutput:
  642. r"""
  643. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  644. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  645. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
  646. the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  647. next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  648. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
  649. pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
  650. - 0 indicates sequence B is a continuation of sequence A,
  651. - 1 indicates sequence B is a random sequence.
  652. Example:
  653. ```python
  654. >>> from transformers import AutoTokenizer, BertForPreTraining
  655. >>> import torch
  656. >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
  657. >>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased")
  658. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  659. >>> outputs = model(**inputs)
  660. >>> prediction_logits = outputs.prediction_logits
  661. >>> seq_relationship_logits = outputs.seq_relationship_logits
  662. ```
  663. """
  664. outputs = self.bert(
  665. input_ids,
  666. attention_mask=attention_mask,
  667. token_type_ids=token_type_ids,
  668. position_ids=position_ids,
  669. inputs_embeds=inputs_embeds,
  670. return_dict=True,
  671. **kwargs,
  672. )
  673. sequence_output, pooled_output = outputs[:2]
  674. prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
  675. total_loss = None
  676. if labels is not None and next_sentence_label is not None:
  677. loss_fct = CrossEntropyLoss()
  678. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  679. next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
  680. total_loss = masked_lm_loss + next_sentence_loss
  681. return BertForPreTrainingOutput(
  682. loss=total_loss,
  683. prediction_logits=prediction_scores,
  684. seq_relationship_logits=seq_relationship_score,
  685. hidden_states=outputs.hidden_states,
  686. attentions=outputs.attentions,
  687. )
  688. @auto_docstring(
  689. custom_intro="""
  690. Bert Model with a `language modeling` head on top for CLM fine-tuning.
  691. """
  692. )
  693. class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
  694. _tied_weights_keys = {
  695. "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
  696. "cls.predictions.decoder.bias": "cls.predictions.bias",
  697. }
  698. def __init__(self, config):
  699. super().__init__(config)
  700. if not config.is_decoder:
  701. logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
  702. self.bert = BertModel(config, add_pooling_layer=False)
  703. self.cls = BertOnlyMLMHead(config)
  704. # Initialize weights and apply final processing
  705. self.post_init()
  706. def get_output_embeddings(self):
  707. return self.cls.predictions.decoder
  708. def set_output_embeddings(self, new_embeddings):
  709. self.cls.predictions.decoder = new_embeddings
  710. self.cls.predictions.bias = new_embeddings.bias
  711. @can_return_tuple
  712. @auto_docstring
  713. def forward(
  714. self,
  715. input_ids: torch.Tensor | None = None,
  716. attention_mask: torch.Tensor | None = None,
  717. token_type_ids: torch.Tensor | None = None,
  718. position_ids: torch.Tensor | None = None,
  719. inputs_embeds: torch.Tensor | None = None,
  720. encoder_hidden_states: torch.Tensor | None = None,
  721. encoder_attention_mask: torch.Tensor | None = None,
  722. labels: torch.Tensor | None = None,
  723. past_key_values: Cache | None = None,
  724. use_cache: bool | None = None,
  725. logits_to_keep: int | torch.Tensor = 0,
  726. **kwargs: Unpack[TransformersKwargs],
  727. ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
  728. r"""
  729. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  730. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  731. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  732. ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
  733. """
  734. if labels is not None:
  735. use_cache = False
  736. outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.bert(
  737. input_ids,
  738. attention_mask=attention_mask,
  739. token_type_ids=token_type_ids,
  740. position_ids=position_ids,
  741. inputs_embeds=inputs_embeds,
  742. encoder_hidden_states=encoder_hidden_states,
  743. encoder_attention_mask=encoder_attention_mask,
  744. past_key_values=past_key_values,
  745. use_cache=use_cache,
  746. return_dict=True,
  747. **kwargs,
  748. )
  749. hidden_states = outputs.last_hidden_state
  750. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  751. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  752. logits = self.cls(hidden_states[:, slice_indices, :])
  753. loss = None
  754. if labels is not None:
  755. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  756. return CausalLMOutputWithCrossAttentions(
  757. loss=loss,
  758. logits=logits,
  759. past_key_values=outputs.past_key_values,
  760. hidden_states=outputs.hidden_states,
  761. attentions=outputs.attentions,
  762. cross_attentions=outputs.cross_attentions,
  763. )
  764. @auto_docstring
  765. class BertForMaskedLM(BertPreTrainedModel):
  766. _tied_weights_keys = {
  767. "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
  768. "cls.predictions.decoder.bias": "cls.predictions.bias",
  769. }
  770. def __init__(self, config):
  771. super().__init__(config)
  772. if config.is_decoder:
  773. logger.warning(
  774. "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
  775. "bi-directional self-attention."
  776. )
  777. self.bert = BertModel(config, add_pooling_layer=False)
  778. self.cls = BertOnlyMLMHead(config)
  779. # Initialize weights and apply final processing
  780. self.post_init()
  781. def get_output_embeddings(self):
  782. return self.cls.predictions.decoder
  783. def set_output_embeddings(self, new_embeddings):
  784. self.cls.predictions.decoder = new_embeddings
  785. self.cls.predictions.bias = new_embeddings.bias
  786. @can_return_tuple
  787. @auto_docstring
  788. def forward(
  789. self,
  790. input_ids: torch.Tensor | None = None,
  791. attention_mask: torch.Tensor | None = None,
  792. token_type_ids: torch.Tensor | None = None,
  793. position_ids: torch.Tensor | None = None,
  794. inputs_embeds: torch.Tensor | None = None,
  795. encoder_hidden_states: torch.Tensor | None = None,
  796. encoder_attention_mask: torch.Tensor | None = None,
  797. labels: torch.Tensor | None = None,
  798. **kwargs: Unpack[TransformersKwargs],
  799. ) -> tuple[torch.Tensor] | MaskedLMOutput:
  800. r"""
  801. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  802. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  803. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  804. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  805. """
  806. outputs = self.bert(
  807. input_ids,
  808. attention_mask=attention_mask,
  809. token_type_ids=token_type_ids,
  810. position_ids=position_ids,
  811. inputs_embeds=inputs_embeds,
  812. encoder_hidden_states=encoder_hidden_states,
  813. encoder_attention_mask=encoder_attention_mask,
  814. return_dict=True,
  815. **kwargs,
  816. )
  817. sequence_output = outputs[0]
  818. prediction_scores = self.cls(sequence_output)
  819. masked_lm_loss = None
  820. if labels is not None:
  821. loss_fct = CrossEntropyLoss() # -100 index = padding token
  822. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  823. return MaskedLMOutput(
  824. loss=masked_lm_loss,
  825. logits=prediction_scores,
  826. hidden_states=outputs.hidden_states,
  827. attentions=outputs.attentions,
  828. )
  829. @auto_docstring(
  830. custom_intro="""
  831. Bert Model with a `next sentence prediction (classification)` head on top.
  832. """
  833. )
  834. class BertForNextSentencePrediction(BertPreTrainedModel):
  835. def __init__(self, config):
  836. super().__init__(config)
  837. self.bert = BertModel(config)
  838. self.cls = BertOnlyNSPHead(config)
  839. # Initialize weights and apply final processing
  840. self.post_init()
  841. @can_return_tuple
  842. @auto_docstring
  843. def forward(
  844. self,
  845. input_ids: torch.Tensor | None = None,
  846. attention_mask: torch.Tensor | None = None,
  847. token_type_ids: torch.Tensor | None = None,
  848. position_ids: torch.Tensor | None = None,
  849. inputs_embeds: torch.Tensor | None = None,
  850. labels: torch.Tensor | None = None,
  851. **kwargs: Unpack[TransformersKwargs],
  852. ) -> tuple[torch.Tensor] | NextSentencePredictorOutput:
  853. r"""
  854. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  855. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
  856. (see `input_ids` docstring). Indices should be in `[0, 1]`:
  857. - 0 indicates sequence B is a continuation of sequence A,
  858. - 1 indicates sequence B is a random sequence.
  859. Example:
  860. ```python
  861. >>> from transformers import AutoTokenizer, BertForNextSentencePrediction
  862. >>> import torch
  863. >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
  864. >>> model = BertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")
  865. >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
  866. >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
  867. >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
  868. >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
  869. >>> logits = outputs.logits
  870. >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
  871. ```
  872. """
  873. outputs = self.bert(
  874. input_ids,
  875. attention_mask=attention_mask,
  876. token_type_ids=token_type_ids,
  877. position_ids=position_ids,
  878. inputs_embeds=inputs_embeds,
  879. return_dict=True,
  880. **kwargs,
  881. )
  882. pooled_output = outputs[1]
  883. seq_relationship_scores = self.cls(pooled_output)
  884. next_sentence_loss = None
  885. if labels is not None:
  886. loss_fct = CrossEntropyLoss()
  887. next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
  888. return NextSentencePredictorOutput(
  889. loss=next_sentence_loss,
  890. logits=seq_relationship_scores,
  891. hidden_states=outputs.hidden_states,
  892. attentions=outputs.attentions,
  893. )
  894. @auto_docstring(
  895. custom_intro="""
  896. Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  897. output) e.g. for GLUE tasks.
  898. """
  899. )
  900. class BertForSequenceClassification(BertPreTrainedModel):
  901. def __init__(self, config):
  902. super().__init__(config)
  903. self.num_labels = config.num_labels
  904. self.config = config
  905. self.bert = BertModel(config)
  906. classifier_dropout = (
  907. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  908. )
  909. self.dropout = nn.Dropout(classifier_dropout)
  910. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  911. # Initialize weights and apply final processing
  912. self.post_init()
  913. @can_return_tuple
  914. @auto_docstring
  915. def forward(
  916. self,
  917. input_ids: torch.Tensor | None = None,
  918. attention_mask: torch.Tensor | None = None,
  919. token_type_ids: torch.Tensor | None = None,
  920. position_ids: torch.Tensor | None = None,
  921. inputs_embeds: torch.Tensor | None = None,
  922. labels: torch.Tensor | None = None,
  923. **kwargs: Unpack[TransformersKwargs],
  924. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  925. r"""
  926. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  927. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  928. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  929. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  930. """
  931. outputs = self.bert(
  932. input_ids,
  933. attention_mask=attention_mask,
  934. token_type_ids=token_type_ids,
  935. position_ids=position_ids,
  936. inputs_embeds=inputs_embeds,
  937. return_dict=True,
  938. **kwargs,
  939. )
  940. pooled_output = outputs[1]
  941. pooled_output = self.dropout(pooled_output)
  942. logits = self.classifier(pooled_output)
  943. loss = None
  944. if labels is not None:
  945. if self.config.problem_type is None:
  946. if self.num_labels == 1:
  947. self.config.problem_type = "regression"
  948. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  949. self.config.problem_type = "single_label_classification"
  950. else:
  951. self.config.problem_type = "multi_label_classification"
  952. if self.config.problem_type == "regression":
  953. loss_fct = MSELoss()
  954. if self.num_labels == 1:
  955. loss = loss_fct(logits.squeeze(), labels.squeeze())
  956. else:
  957. loss = loss_fct(logits, labels)
  958. elif self.config.problem_type == "single_label_classification":
  959. loss_fct = CrossEntropyLoss()
  960. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  961. elif self.config.problem_type == "multi_label_classification":
  962. loss_fct = BCEWithLogitsLoss()
  963. loss = loss_fct(logits, labels)
  964. return SequenceClassifierOutput(
  965. loss=loss,
  966. logits=logits,
  967. hidden_states=outputs.hidden_states,
  968. attentions=outputs.attentions,
  969. )
  970. @auto_docstring
  971. class BertForMultipleChoice(BertPreTrainedModel):
  972. def __init__(self, config):
  973. super().__init__(config)
  974. self.bert = BertModel(config)
  975. classifier_dropout = (
  976. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  977. )
  978. self.dropout = nn.Dropout(classifier_dropout)
  979. self.classifier = nn.Linear(config.hidden_size, 1)
  980. # Initialize weights and apply final processing
  981. self.post_init()
  982. @can_return_tuple
  983. @auto_docstring
  984. def forward(
  985. self,
  986. input_ids: torch.Tensor | None = None,
  987. attention_mask: torch.Tensor | None = None,
  988. token_type_ids: torch.Tensor | None = None,
  989. position_ids: torch.Tensor | None = None,
  990. inputs_embeds: torch.Tensor | None = None,
  991. labels: torch.Tensor | None = None,
  992. **kwargs: Unpack[TransformersKwargs],
  993. ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
  994. r"""
  995. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  996. Indices of input sequence tokens in the vocabulary.
  997. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  998. [`PreTrainedTokenizer.__call__`] for details.
  999. [What are input IDs?](../glossary#input-ids)
  1000. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1001. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1002. 1]`:
  1003. - 0 corresponds to a *sentence A* token,
  1004. - 1 corresponds to a *sentence B* token.
  1005. [What are token type IDs?](../glossary#token-type-ids)
  1006. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1007. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1008. config.max_position_embeddings - 1]`.
  1009. [What are position IDs?](../glossary#position-ids)
  1010. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1011. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1012. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1013. model's internal embedding lookup matrix.
  1014. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1015. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1016. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1017. `input_ids` above)
  1018. """
  1019. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1020. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1021. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1022. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1023. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1024. inputs_embeds = (
  1025. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1026. if inputs_embeds is not None
  1027. else None
  1028. )
  1029. outputs = self.bert(
  1030. input_ids,
  1031. attention_mask=attention_mask,
  1032. token_type_ids=token_type_ids,
  1033. position_ids=position_ids,
  1034. inputs_embeds=inputs_embeds,
  1035. return_dict=True,
  1036. **kwargs,
  1037. )
  1038. pooled_output = outputs[1]
  1039. pooled_output = self.dropout(pooled_output)
  1040. logits = self.classifier(pooled_output)
  1041. reshaped_logits = logits.view(-1, num_choices)
  1042. loss = None
  1043. if labels is not None:
  1044. loss_fct = CrossEntropyLoss()
  1045. loss = loss_fct(reshaped_logits, labels)
  1046. return MultipleChoiceModelOutput(
  1047. loss=loss,
  1048. logits=reshaped_logits,
  1049. hidden_states=outputs.hidden_states,
  1050. attentions=outputs.attentions,
  1051. )
  1052. @auto_docstring
  1053. class BertForTokenClassification(BertPreTrainedModel):
  1054. def __init__(self, config):
  1055. super().__init__(config)
  1056. self.num_labels = config.num_labels
  1057. self.bert = BertModel(config, add_pooling_layer=False)
  1058. classifier_dropout = (
  1059. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1060. )
  1061. self.dropout = nn.Dropout(classifier_dropout)
  1062. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1063. # Initialize weights and apply final processing
  1064. self.post_init()
  1065. @can_return_tuple
  1066. @auto_docstring
  1067. def forward(
  1068. self,
  1069. input_ids: torch.Tensor | None = None,
  1070. attention_mask: torch.Tensor | None = None,
  1071. token_type_ids: torch.Tensor | None = None,
  1072. position_ids: torch.Tensor | None = None,
  1073. inputs_embeds: torch.Tensor | None = None,
  1074. labels: torch.Tensor | None = None,
  1075. **kwargs: Unpack[TransformersKwargs],
  1076. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  1077. r"""
  1078. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1079. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1080. """
  1081. outputs = self.bert(
  1082. input_ids,
  1083. attention_mask=attention_mask,
  1084. token_type_ids=token_type_ids,
  1085. position_ids=position_ids,
  1086. inputs_embeds=inputs_embeds,
  1087. return_dict=True,
  1088. **kwargs,
  1089. )
  1090. sequence_output = outputs[0]
  1091. sequence_output = self.dropout(sequence_output)
  1092. logits = self.classifier(sequence_output)
  1093. loss = None
  1094. if labels is not None:
  1095. loss_fct = CrossEntropyLoss()
  1096. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1097. return TokenClassifierOutput(
  1098. loss=loss,
  1099. logits=logits,
  1100. hidden_states=outputs.hidden_states,
  1101. attentions=outputs.attentions,
  1102. )
  1103. @auto_docstring
  1104. class BertForQuestionAnswering(BertPreTrainedModel):
  1105. def __init__(self, config):
  1106. super().__init__(config)
  1107. self.num_labels = config.num_labels
  1108. self.bert = BertModel(config, add_pooling_layer=False)
  1109. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1110. # Initialize weights and apply final processing
  1111. self.post_init()
  1112. @can_return_tuple
  1113. @auto_docstring
  1114. def forward(
  1115. self,
  1116. input_ids: torch.Tensor | None = None,
  1117. attention_mask: torch.Tensor | None = None,
  1118. token_type_ids: torch.Tensor | None = None,
  1119. position_ids: torch.Tensor | None = None,
  1120. inputs_embeds: torch.Tensor | None = None,
  1121. start_positions: torch.Tensor | None = None,
  1122. end_positions: torch.Tensor | None = None,
  1123. **kwargs: Unpack[TransformersKwargs],
  1124. ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
  1125. outputs = self.bert(
  1126. input_ids,
  1127. attention_mask=attention_mask,
  1128. token_type_ids=token_type_ids,
  1129. position_ids=position_ids,
  1130. inputs_embeds=inputs_embeds,
  1131. return_dict=True,
  1132. **kwargs,
  1133. )
  1134. sequence_output = outputs[0]
  1135. logits = self.qa_outputs(sequence_output)
  1136. start_logits, end_logits = logits.split(1, dim=-1)
  1137. start_logits = start_logits.squeeze(-1).contiguous()
  1138. end_logits = end_logits.squeeze(-1).contiguous()
  1139. total_loss = None
  1140. if start_positions is not None and end_positions is not None:
  1141. # If we are on multi-GPU, split add a dimension
  1142. if len(start_positions.size()) > 1:
  1143. start_positions = start_positions.squeeze(-1)
  1144. if len(end_positions.size()) > 1:
  1145. end_positions = end_positions.squeeze(-1)
  1146. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1147. ignored_index = start_logits.size(1)
  1148. start_positions = start_positions.clamp(0, ignored_index)
  1149. end_positions = end_positions.clamp(0, ignored_index)
  1150. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1151. start_loss = loss_fct(start_logits, start_positions)
  1152. end_loss = loss_fct(end_logits, end_positions)
  1153. total_loss = (start_loss + end_loss) / 2
  1154. return QuestionAnsweringModelOutput(
  1155. loss=total_loss,
  1156. start_logits=start_logits,
  1157. end_logits=end_logits,
  1158. hidden_states=outputs.hidden_states,
  1159. attentions=outputs.attentions,
  1160. )
# Public names exported by this module (what `from <module> import *` exposes).
__all__ = [
    "BertForMaskedLM",
    "BertForMultipleChoice",
    "BertForNextSentencePrediction",
    "BertForPreTraining",
    "BertForQuestionAnswering",
    "BertForSequenceClassification",
    "BertForTokenClassification",
    "BertLayer",
    "BertLMHeadModel",
    "BertModel",
    "BertPreTrainedModel",
]