modeling_camembert.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/camembert/modular_camembert.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_camembert.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2019 Inria, Facebook AI Research and the HuggingFace Inc. team.
  8. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. import torch
  23. import torch.nn as nn
  24. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  25. from ... import initialization as init
  26. from ...activations import ACT2FN, gelu
  27. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  28. from ...generation import GenerationMixin
  29. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import (
  32. BaseModelOutputWithPastAndCrossAttentions,
  33. BaseModelOutputWithPoolingAndCrossAttentions,
  34. CausalLMOutputWithCrossAttentions,
  35. MaskedLMOutput,
  36. MultipleChoiceModelOutput,
  37. QuestionAnsweringModelOutput,
  38. SequenceClassifierOutput,
  39. TokenClassifierOutput,
  40. )
  41. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  42. from ...processing_utils import Unpack
  43. from ...pytorch_utils import apply_chunking_to_forward
  44. from ...utils import TransformersKwargs, auto_docstring, logging
  45. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  46. from ...utils.output_capturing import capture_outputs
  47. from .configuration_camembert import CamembertConfig
  48. logger = logging.get_logger(__name__)
  49. class CamembertEmbeddings(nn.Module):
  50. """Construct the embeddings from word, position and token_type embeddings."""
  51. def __init__(self, config):
  52. super().__init__()
  53. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  54. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  55. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  56. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  57. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  58. self.register_buffer(
  59. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  60. )
  61. self.register_buffer(
  62. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  63. )
  64. self.padding_idx = config.pad_token_id
  65. self.position_embeddings = nn.Embedding(
  66. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  67. )
  68. def forward(
  69. self,
  70. input_ids: torch.LongTensor | None = None,
  71. token_type_ids: torch.LongTensor | None = None,
  72. position_ids: torch.LongTensor | None = None,
  73. inputs_embeds: torch.FloatTensor | None = None,
  74. past_key_values_length: int = 0,
  75. ) -> torch.Tensor:
  76. if position_ids is None:
  77. if input_ids is not None:
  78. # Create the position ids from the input token ids. Any padded tokens remain padded.
  79. position_ids = self.create_position_ids_from_input_ids(
  80. input_ids, self.padding_idx, past_key_values_length
  81. )
  82. else:
  83. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, self.padding_idx)
  84. if input_ids is not None:
  85. input_shape = input_ids.size()
  86. else:
  87. input_shape = inputs_embeds.size()[:-1]
  88. batch_size, seq_length = input_shape
  89. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  90. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  91. # issue #5664
  92. if token_type_ids is None:
  93. if hasattr(self, "token_type_ids"):
  94. # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
  95. buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
  96. buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
  97. token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
  98. else:
  99. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  100. if inputs_embeds is None:
  101. inputs_embeds = self.word_embeddings(input_ids)
  102. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  103. embeddings = inputs_embeds + token_type_embeddings
  104. position_embeddings = self.position_embeddings(position_ids)
  105. embeddings = embeddings + position_embeddings
  106. embeddings = self.LayerNorm(embeddings)
  107. embeddings = self.dropout(embeddings)
  108. return embeddings
  109. @staticmethod
  110. def create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx):
  111. """
  112. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  113. Args:
  114. inputs_embeds: torch.Tensor
  115. Returns: torch.Tensor
  116. """
  117. input_shape = inputs_embeds.size()[:-1]
  118. sequence_length = input_shape[1]
  119. position_ids = torch.arange(
  120. padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  121. )
  122. return position_ids.unsqueeze(0).expand(input_shape)
  123. @staticmethod
  124. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  125. """
  126. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  127. are ignored. This is modified from fairseq's `utils.make_positions`.
  128. Args:
  129. x: torch.Tensor x:
  130. Returns: torch.Tensor
  131. """
  132. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  133. mask = input_ids.ne(padding_idx).int()
  134. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  135. return incremental_indices.long() + padding_idx
  136. def eager_attention_forward(
  137. module: nn.Module,
  138. query: torch.Tensor,
  139. key: torch.Tensor,
  140. value: torch.Tensor,
  141. attention_mask: torch.Tensor | None,
  142. scaling: float | None = None,
  143. dropout: float = 0.0,
  144. **kwargs: Unpack[TransformersKwargs],
  145. ):
  146. if scaling is None:
  147. scaling = query.size(-1) ** -0.5
  148. # Take the dot product between "query" and "key" to get the raw attention scores.
  149. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  150. if attention_mask is not None:
  151. attn_weights = attn_weights + attention_mask
  152. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  153. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  154. attn_output = torch.matmul(attn_weights, value)
  155. attn_output = attn_output.transpose(1, 2).contiguous()
  156. return attn_output, attn_weights
  157. class CamembertSelfAttention(nn.Module):
  158. def __init__(self, config, is_causal=False, layer_idx=None):
  159. super().__init__()
  160. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  161. raise ValueError(
  162. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  163. f"heads ({config.num_attention_heads})"
  164. )
  165. self.config = config
  166. self.num_attention_heads = config.num_attention_heads
  167. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  168. self.all_head_size = self.num_attention_heads * self.attention_head_size
  169. self.scaling = self.attention_head_size**-0.5
  170. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  171. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  172. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  173. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  174. self.is_decoder = config.is_decoder
  175. self.is_causal = is_causal
  176. self.layer_idx = layer_idx
  177. def forward(
  178. self,
  179. hidden_states: torch.Tensor,
  180. attention_mask: torch.FloatTensor | None = None,
  181. past_key_values: Cache | None = None,
  182. **kwargs: Unpack[TransformersKwargs],
  183. ) -> tuple[torch.Tensor]:
  184. input_shape = hidden_states.shape[:-1]
  185. hidden_shape = (*input_shape, -1, self.attention_head_size)
  186. # get all proj
  187. query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
  188. key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
  189. value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
  190. if past_key_values is not None:
  191. # decoder-only camembert can have a simple dynamic cache for example
  192. current_past_key_values = past_key_values
  193. if isinstance(past_key_values, EncoderDecoderCache):
  194. current_past_key_values = past_key_values.self_attention_cache
  195. # save all key/value_layer to cache to be re-used for fast auto-regressive generation
  196. key_layer, value_layer = current_past_key_values.update(key_layer, value_layer, self.layer_idx)
  197. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  198. self.config._attn_implementation, eager_attention_forward
  199. )
  200. attn_output, attn_weights = attention_interface(
  201. self,
  202. query_layer,
  203. key_layer,
  204. value_layer,
  205. attention_mask,
  206. dropout=0.0 if not self.training else self.dropout.p,
  207. scaling=self.scaling,
  208. **kwargs,
  209. )
  210. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  211. return attn_output, attn_weights
  212. class CamembertCrossAttention(nn.Module):
  213. def __init__(self, config, is_causal=False, layer_idx=None):
  214. super().__init__()
  215. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  216. raise ValueError(
  217. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  218. f"heads ({config.num_attention_heads})"
  219. )
  220. self.config = config
  221. self.num_attention_heads = config.num_attention_heads
  222. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  223. self.all_head_size = self.num_attention_heads * self.attention_head_size
  224. self.scaling = self.attention_head_size**-0.5
  225. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  226. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  227. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  228. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  229. self.is_causal = is_causal
  230. self.layer_idx = layer_idx
  231. def forward(
  232. self,
  233. hidden_states: torch.Tensor,
  234. encoder_hidden_states: torch.FloatTensor | None = None,
  235. attention_mask: torch.FloatTensor | None = None,
  236. past_key_values: EncoderDecoderCache | None = None,
  237. **kwargs: Unpack[TransformersKwargs],
  238. ) -> tuple[torch.Tensor]:
  239. # determine input shapes
  240. input_shape = hidden_states.shape[:-1]
  241. hidden_shape = (*input_shape, -1, self.attention_head_size)
  242. # get query proj
  243. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  244. is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
  245. if past_key_values is not None and is_updated:
  246. # reuse k,v, cross_attentions
  247. key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
  248. value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
  249. else:
  250. kv_shape = (*encoder_hidden_states.shape[:-1], -1, self.attention_head_size)
  251. key_layer = self.key(encoder_hidden_states).view(kv_shape).transpose(1, 2)
  252. value_layer = self.value(encoder_hidden_states).view(kv_shape).transpose(1, 2)
  253. if past_key_values is not None:
  254. # save all states to the cache
  255. key_layer, value_layer = past_key_values.cross_attention_cache.update(
  256. key_layer, value_layer, self.layer_idx
  257. )
  258. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  259. past_key_values.is_updated[self.layer_idx] = True
  260. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  261. self.config._attn_implementation, eager_attention_forward
  262. )
  263. attn_output, attn_weights = attention_interface(
  264. self,
  265. query_layer,
  266. key_layer,
  267. value_layer,
  268. attention_mask,
  269. dropout=0.0 if not self.training else self.dropout.p,
  270. scaling=self.scaling,
  271. **kwargs,
  272. )
  273. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  274. return attn_output, attn_weights
  275. class CamembertSelfOutput(nn.Module):
  276. def __init__(self, config):
  277. super().__init__()
  278. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  279. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  280. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  281. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  282. hidden_states = self.dense(hidden_states)
  283. hidden_states = self.dropout(hidden_states)
  284. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  285. return hidden_states
  286. class CamembertAttention(nn.Module):
  287. def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
  288. super().__init__()
  289. self.is_cross_attention = is_cross_attention
  290. attention_class = CamembertCrossAttention if is_cross_attention else CamembertSelfAttention
  291. self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
  292. self.output = CamembertSelfOutput(config)
  293. def forward(
  294. self,
  295. hidden_states: torch.Tensor,
  296. attention_mask: torch.FloatTensor | None = None,
  297. encoder_hidden_states: torch.FloatTensor | None = None,
  298. encoder_attention_mask: torch.FloatTensor | None = None,
  299. past_key_values: Cache | None = None,
  300. **kwargs: Unpack[TransformersKwargs],
  301. ) -> tuple[torch.Tensor]:
  302. attention_mask = attention_mask if not self.is_cross_attention else encoder_attention_mask
  303. attention_output, attn_weights = self.self(
  304. hidden_states,
  305. encoder_hidden_states=encoder_hidden_states,
  306. attention_mask=attention_mask,
  307. past_key_values=past_key_values,
  308. **kwargs,
  309. )
  310. attention_output = self.output(attention_output, hidden_states)
  311. return attention_output, attn_weights
  312. class CamembertIntermediate(nn.Module):
  313. def __init__(self, config):
  314. super().__init__()
  315. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  316. if isinstance(config.hidden_act, str):
  317. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  318. else:
  319. self.intermediate_act_fn = config.hidden_act
  320. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  321. hidden_states = self.dense(hidden_states)
  322. hidden_states = self.intermediate_act_fn(hidden_states)
  323. return hidden_states
  324. class CamembertOutput(nn.Module):
  325. def __init__(self, config):
  326. super().__init__()
  327. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  328. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  329. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  330. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  331. hidden_states = self.dense(hidden_states)
  332. hidden_states = self.dropout(hidden_states)
  333. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  334. return hidden_states
  335. class CamembertLayer(GradientCheckpointingLayer):
  336. def __init__(self, config, layer_idx=None):
  337. super().__init__()
  338. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  339. self.seq_len_dim = 1
  340. self.attention = CamembertAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
  341. self.is_decoder = config.is_decoder
  342. self.add_cross_attention = config.add_cross_attention
  343. if self.add_cross_attention:
  344. if not self.is_decoder:
  345. raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
  346. self.crossattention = CamembertAttention(
  347. config,
  348. is_causal=False,
  349. layer_idx=layer_idx,
  350. is_cross_attention=True,
  351. )
  352. self.intermediate = CamembertIntermediate(config)
  353. self.output = CamembertOutput(config)
  354. def forward(
  355. self,
  356. hidden_states: torch.Tensor,
  357. attention_mask: torch.FloatTensor | None = None,
  358. encoder_hidden_states: torch.FloatTensor | None = None,
  359. encoder_attention_mask: torch.FloatTensor | None = None,
  360. past_key_values: Cache | None = None,
  361. **kwargs: Unpack[TransformersKwargs],
  362. ) -> torch.Tensor:
  363. self_attention_output, _ = self.attention(
  364. hidden_states,
  365. attention_mask,
  366. past_key_values=past_key_values,
  367. **kwargs,
  368. )
  369. attention_output = self_attention_output
  370. if self.is_decoder and encoder_hidden_states is not None:
  371. if not hasattr(self, "crossattention"):
  372. raise ValueError(
  373. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  374. " by setting `config.add_cross_attention=True`"
  375. )
  376. cross_attention_output, _ = self.crossattention(
  377. self_attention_output,
  378. None, # attention_mask
  379. encoder_hidden_states,
  380. encoder_attention_mask,
  381. past_key_values=past_key_values,
  382. **kwargs,
  383. )
  384. attention_output = cross_attention_output
  385. layer_output = apply_chunking_to_forward(
  386. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  387. )
  388. return layer_output
  389. def feed_forward_chunk(self, attention_output):
  390. intermediate_output = self.intermediate(attention_output)
  391. layer_output = self.output(intermediate_output, attention_output)
  392. return layer_output
  393. class CamembertLMHead(nn.Module):
  394. """Camembert Head for masked language modeling."""
  395. def __init__(self, config):
  396. super().__init__()
  397. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  398. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  399. self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
  400. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  401. def forward(self, features, **kwargs):
  402. x = self.dense(features)
  403. x = gelu(x)
  404. x = self.layer_norm(x)
  405. # project back to size of vocabulary with bias
  406. x = self.decoder(x)
  407. return x
  408. @auto_docstring
  409. class CamembertPreTrainedModel(PreTrainedModel):
  410. config_class = CamembertConfig
  411. base_model_prefix = "roberta"
  412. supports_gradient_checkpointing = True
  413. _supports_flash_attn = True
  414. _supports_sdpa = True
  415. _supports_flex_attn = True
  416. _supports_attention_backend = True
  417. _can_record_outputs = {
  418. "hidden_states": CamembertLayer,
  419. "attentions": CamembertSelfAttention,
  420. "cross_attentions": CamembertCrossAttention,
  421. }
  422. @torch.no_grad()
  423. def _init_weights(self, module):
  424. """Initialize the weights"""
  425. super()._init_weights(module)
  426. if isinstance(module, CamembertLMHead):
  427. init.zeros_(module.bias)
  428. elif isinstance(module, CamembertEmbeddings):
  429. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  430. init.zeros_(module.token_type_ids)
  431. class CamembertEncoder(nn.Module):
  432. def __init__(self, config):
  433. super().__init__()
  434. self.config = config
  435. self.layer = nn.ModuleList([CamembertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  436. def forward(
  437. self,
  438. hidden_states: torch.Tensor,
  439. attention_mask: torch.FloatTensor | None = None,
  440. encoder_hidden_states: torch.FloatTensor | None = None,
  441. encoder_attention_mask: torch.FloatTensor | None = None,
  442. past_key_values: Cache | None = None,
  443. use_cache: bool | None = None,
  444. **kwargs: Unpack[TransformersKwargs],
  445. ) -> tuple[torch.Tensor] | BaseModelOutputWithPastAndCrossAttentions:
  446. for i, layer_module in enumerate(self.layer):
  447. hidden_states = layer_module(
  448. hidden_states,
  449. attention_mask,
  450. encoder_hidden_states, # as a positional argument for gradient checkpointing
  451. encoder_attention_mask=encoder_attention_mask,
  452. past_key_values=past_key_values,
  453. **kwargs,
  454. )
  455. return BaseModelOutputWithPastAndCrossAttentions(
  456. last_hidden_state=hidden_states,
  457. past_key_values=past_key_values if use_cache else None,
  458. )
  459. class CamembertPooler(nn.Module):
  460. def __init__(self, config):
  461. super().__init__()
  462. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  463. self.activation = nn.Tanh()
  464. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  465. # We "pool" the model by simply taking the hidden state corresponding
  466. # to the first token.
  467. first_token_tensor = hidden_states[:, 0]
  468. pooled_output = self.dense(first_token_tensor)
  469. pooled_output = self.activation(pooled_output)
  470. return pooled_output
  471. @auto_docstring(
  472. custom_intro="""
  473. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  474. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  475. all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  476. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  477. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  478. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  479. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  480. """
  481. )
  482. class CamembertModel(CamembertPreTrainedModel):
  483. _no_split_modules = ["CamembertEmbeddings", "CamembertLayer"]
  484. def __init__(self, config, add_pooling_layer=True):
  485. r"""
  486. add_pooling_layer (bool, *optional*, defaults to `True`):
  487. Whether to add a pooling layer
  488. """
  489. super().__init__(config)
  490. self.config = config
  491. self.gradient_checkpointing = False
  492. self.embeddings = CamembertEmbeddings(config)
  493. self.encoder = CamembertEncoder(config)
  494. self.pooler = CamembertPooler(config) if add_pooling_layer else None
  495. # Initialize weights and apply final processing
  496. self.post_init()
  497. def get_input_embeddings(self):
  498. return self.embeddings.word_embeddings
  499. def set_input_embeddings(self, value):
  500. self.embeddings.word_embeddings = value
  501. @merge_with_config_defaults
  502. @capture_outputs
  503. @auto_docstring
  504. def forward(
  505. self,
  506. input_ids: torch.Tensor | None = None,
  507. attention_mask: torch.Tensor | None = None,
  508. token_type_ids: torch.Tensor | None = None,
  509. position_ids: torch.Tensor | None = None,
  510. inputs_embeds: torch.Tensor | None = None,
  511. encoder_hidden_states: torch.Tensor | None = None,
  512. encoder_attention_mask: torch.Tensor | None = None,
  513. past_key_values: Cache | None = None,
  514. use_cache: bool | None = None,
  515. **kwargs: Unpack[TransformersKwargs],
  516. ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
  517. if (input_ids is None) ^ (inputs_embeds is not None):
  518. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  519. if self.config.is_decoder:
  520. use_cache = use_cache if use_cache is not None else self.config.use_cache
  521. else:
  522. use_cache = False
  523. if use_cache and past_key_values is None:
  524. past_key_values = (
  525. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  526. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  527. else DynamicCache(config=self.config)
  528. )
  529. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  530. embedding_output = self.embeddings(
  531. input_ids=input_ids,
  532. position_ids=position_ids,
  533. token_type_ids=token_type_ids,
  534. inputs_embeds=inputs_embeds,
  535. past_key_values_length=past_key_values_length,
  536. )
  537. attention_mask, encoder_attention_mask = self._create_attention_masks(
  538. attention_mask=attention_mask,
  539. encoder_attention_mask=encoder_attention_mask,
  540. embedding_output=embedding_output,
  541. encoder_hidden_states=encoder_hidden_states,
  542. past_key_values=past_key_values,
  543. )
  544. encoder_outputs = self.encoder(
  545. embedding_output,
  546. attention_mask=attention_mask,
  547. encoder_hidden_states=encoder_hidden_states,
  548. encoder_attention_mask=encoder_attention_mask,
  549. past_key_values=past_key_values,
  550. use_cache=use_cache,
  551. position_ids=position_ids,
  552. **kwargs,
  553. )
  554. sequence_output = encoder_outputs.last_hidden_state
  555. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  556. return BaseModelOutputWithPoolingAndCrossAttentions(
  557. last_hidden_state=sequence_output,
  558. pooler_output=pooled_output,
  559. past_key_values=encoder_outputs.past_key_values,
  560. )
  561. def _create_attention_masks(
  562. self,
  563. attention_mask,
  564. encoder_attention_mask,
  565. embedding_output,
  566. encoder_hidden_states,
  567. past_key_values,
  568. ):
  569. if self.config.is_decoder:
  570. attention_mask = create_causal_mask(
  571. config=self.config,
  572. inputs_embeds=embedding_output,
  573. attention_mask=attention_mask,
  574. past_key_values=past_key_values,
  575. )
  576. else:
  577. attention_mask = create_bidirectional_mask(
  578. config=self.config,
  579. inputs_embeds=embedding_output,
  580. attention_mask=attention_mask,
  581. )
  582. if encoder_attention_mask is not None:
  583. encoder_attention_mask = create_bidirectional_mask(
  584. config=self.config,
  585. inputs_embeds=embedding_output,
  586. attention_mask=encoder_attention_mask,
  587. encoder_hidden_states=encoder_hidden_states,
  588. )
  589. return attention_mask, encoder_attention_mask
  590. @auto_docstring
  591. class CamembertForMaskedLM(CamembertPreTrainedModel):
  592. _tied_weights_keys = {
  593. "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
  594. "lm_head.decoder.bias": "lm_head.bias",
  595. }
  596. def __init__(self, config):
  597. super().__init__(config)
  598. if config.is_decoder:
  599. logger.warning(
  600. "If you want to use `CamembertForMaskedLM` make sure `config.is_decoder=False` for "
  601. "bi-directional self-attention."
  602. )
  603. self.lm_head = CamembertLMHead(config)
  604. self.roberta = CamembertModel(config, add_pooling_layer=False)
  605. # Initialize weights and apply final processing
  606. self.post_init()
  607. def get_output_embeddings(self):
  608. return self.lm_head.decoder
  609. def set_output_embeddings(self, new_embeddings):
  610. self.lm_head.decoder = new_embeddings
  611. @can_return_tuple
  612. @auto_docstring
  613. def forward(
  614. self,
  615. input_ids: torch.LongTensor | None = None,
  616. attention_mask: torch.FloatTensor | None = None,
  617. token_type_ids: torch.LongTensor | None = None,
  618. position_ids: torch.LongTensor | None = None,
  619. inputs_embeds: torch.FloatTensor | None = None,
  620. encoder_hidden_states: torch.FloatTensor | None = None,
  621. encoder_attention_mask: torch.FloatTensor | None = None,
  622. labels: torch.LongTensor | None = None,
  623. **kwargs: Unpack[TransformersKwargs],
  624. ) -> tuple[torch.Tensor] | MaskedLMOutput:
  625. r"""
  626. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  627. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  628. - 0 corresponds to a *sentence A* token,
  629. - 1 corresponds to a *sentence B* token.
  630. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  631. >= 2. All the value in this tensor should be always < type_vocab_size.
  632. [What are token type IDs?](../glossary#token-type-ids)
  633. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  634. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  635. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  636. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  637. """
  638. outputs = self.roberta(
  639. input_ids,
  640. attention_mask=attention_mask,
  641. token_type_ids=token_type_ids,
  642. position_ids=position_ids,
  643. inputs_embeds=inputs_embeds,
  644. encoder_hidden_states=encoder_hidden_states,
  645. encoder_attention_mask=encoder_attention_mask,
  646. return_dict=True,
  647. **kwargs,
  648. )
  649. sequence_output = outputs[0]
  650. prediction_scores = self.lm_head(sequence_output)
  651. masked_lm_loss = None
  652. if labels is not None:
  653. # move labels to correct device
  654. labels = labels.to(prediction_scores.device)
  655. loss_fct = CrossEntropyLoss()
  656. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  657. return MaskedLMOutput(
  658. loss=masked_lm_loss,
  659. logits=prediction_scores,
  660. hidden_states=outputs.hidden_states,
  661. attentions=outputs.attentions,
  662. )
  663. class CamembertClassificationHead(nn.Module):
  664. """Head for sentence-level classification tasks."""
  665. def __init__(self, config):
  666. super().__init__()
  667. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  668. classifier_dropout = (
  669. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  670. )
  671. self.dropout = nn.Dropout(classifier_dropout)
  672. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  673. def forward(self, features, **kwargs):
  674. x = features[:, 0, :] # take <s> token (equiv. to [CLS])
  675. x = self.dropout(x)
  676. x = self.dense(x)
  677. x = torch.tanh(x)
  678. x = self.dropout(x)
  679. x = self.out_proj(x)
  680. return x
  681. @auto_docstring(
  682. custom_intro="""
  683. Camembert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  684. pooled output) e.g. for GLUE tasks.
  685. """
  686. )
  687. class CamembertForSequenceClassification(CamembertPreTrainedModel):
  688. def __init__(self, config):
  689. super().__init__(config)
  690. self.num_labels = config.num_labels
  691. self.config = config
  692. self.classifier = CamembertClassificationHead(config)
  693. self.roberta = CamembertModel(config, add_pooling_layer=False)
  694. # Initialize weights and apply final processing
  695. self.post_init()
  696. @can_return_tuple
  697. @auto_docstring
  698. def forward(
  699. self,
  700. input_ids: torch.LongTensor | None = None,
  701. attention_mask: torch.FloatTensor | None = None,
  702. token_type_ids: torch.LongTensor | None = None,
  703. position_ids: torch.LongTensor | None = None,
  704. inputs_embeds: torch.FloatTensor | None = None,
  705. labels: torch.LongTensor | None = None,
  706. **kwargs: Unpack[TransformersKwargs],
  707. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  708. r"""
  709. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  710. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  711. - 0 corresponds to a *sentence A* token,
  712. - 1 corresponds to a *sentence B* token.
  713. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  714. >= 2. All the value in this tensor should be always < type_vocab_size.
  715. [What are token type IDs?](../glossary#token-type-ids)
  716. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  717. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  718. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  719. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  720. """
  721. outputs = self.roberta(
  722. input_ids,
  723. attention_mask=attention_mask,
  724. token_type_ids=token_type_ids,
  725. position_ids=position_ids,
  726. inputs_embeds=inputs_embeds,
  727. return_dict=True,
  728. **kwargs,
  729. )
  730. sequence_output = outputs[0]
  731. logits = self.classifier(sequence_output)
  732. loss = None
  733. if labels is not None:
  734. # move labels to correct device
  735. labels = labels.to(logits.device)
  736. if self.config.problem_type is None:
  737. if self.num_labels == 1:
  738. self.config.problem_type = "regression"
  739. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  740. self.config.problem_type = "single_label_classification"
  741. else:
  742. self.config.problem_type = "multi_label_classification"
  743. if self.config.problem_type == "regression":
  744. loss_fct = MSELoss()
  745. if self.num_labels == 1:
  746. loss = loss_fct(logits.squeeze(), labels.squeeze())
  747. else:
  748. loss = loss_fct(logits, labels)
  749. elif self.config.problem_type == "single_label_classification":
  750. loss_fct = CrossEntropyLoss()
  751. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  752. elif self.config.problem_type == "multi_label_classification":
  753. loss_fct = BCEWithLogitsLoss()
  754. loss = loss_fct(logits, labels)
  755. return SequenceClassifierOutput(
  756. loss=loss,
  757. logits=logits,
  758. hidden_states=outputs.hidden_states,
  759. attentions=outputs.attentions,
  760. )
  761. @auto_docstring
  762. class CamembertForMultipleChoice(CamembertPreTrainedModel):
  763. def __init__(self, config):
  764. super().__init__(config)
  765. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  766. self.classifier = nn.Linear(config.hidden_size, 1)
  767. self.roberta = CamembertModel(config, add_pooling_layer=False)
  768. # Initialize weights and apply final processing
  769. self.post_init()
  770. @can_return_tuple
  771. @auto_docstring
  772. def forward(
  773. self,
  774. input_ids: torch.LongTensor | None = None,
  775. token_type_ids: torch.LongTensor | None = None,
  776. attention_mask: torch.FloatTensor | None = None,
  777. labels: torch.LongTensor | None = None,
  778. position_ids: torch.LongTensor | None = None,
  779. inputs_embeds: torch.FloatTensor | None = None,
  780. **kwargs: Unpack[TransformersKwargs],
  781. ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
  782. r"""
  783. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  784. Indices of input sequence tokens in the vocabulary.
  785. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  786. [`PreTrainedTokenizer.__call__`] for details.
  787. [What are input IDs?](../glossary#input-ids)
  788. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  789. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  790. - 0 corresponds to a *sentence A* token,
  791. - 1 corresponds to a *sentence B* token.
  792. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  793. >= 2. All the value in this tensor should be always < type_vocab_size.
  794. [What are token type IDs?](../glossary#token-type-ids)
  795. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  796. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  797. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  798. `input_ids` above)
  799. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  800. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  801. config.max_position_embeddings - 1]`.
  802. [What are position IDs?](../glossary#position-ids)
  803. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  804. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  805. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  806. model's internal embedding lookup matrix.
  807. """
  808. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  809. flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  810. flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  811. flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  812. flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  813. flat_inputs_embeds = (
  814. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  815. if inputs_embeds is not None
  816. else None
  817. )
  818. outputs = self.roberta(
  819. flat_input_ids,
  820. position_ids=flat_position_ids,
  821. token_type_ids=flat_token_type_ids,
  822. attention_mask=flat_attention_mask,
  823. inputs_embeds=flat_inputs_embeds,
  824. return_dict=True,
  825. **kwargs,
  826. )
  827. pooled_output = outputs[1]
  828. pooled_output = self.dropout(pooled_output)
  829. logits = self.classifier(pooled_output)
  830. reshaped_logits = logits.view(-1, num_choices)
  831. loss = None
  832. if labels is not None:
  833. # move labels to correct device
  834. labels = labels.to(reshaped_logits.device)
  835. loss_fct = CrossEntropyLoss()
  836. loss = loss_fct(reshaped_logits, labels)
  837. return MultipleChoiceModelOutput(
  838. loss=loss,
  839. logits=reshaped_logits,
  840. hidden_states=outputs.hidden_states,
  841. attentions=outputs.attentions,
  842. )
  843. @auto_docstring
  844. class CamembertForTokenClassification(CamembertPreTrainedModel):
  845. def __init__(self, config):
  846. super().__init__(config)
  847. self.num_labels = config.num_labels
  848. classifier_dropout = (
  849. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  850. )
  851. self.dropout = nn.Dropout(classifier_dropout)
  852. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  853. self.roberta = CamembertModel(config, add_pooling_layer=False)
  854. # Initialize weights and apply final processing
  855. self.post_init()
  856. @can_return_tuple
  857. @auto_docstring
  858. def forward(
  859. self,
  860. input_ids: torch.LongTensor | None = None,
  861. attention_mask: torch.FloatTensor | None = None,
  862. token_type_ids: torch.LongTensor | None = None,
  863. position_ids: torch.LongTensor | None = None,
  864. inputs_embeds: torch.FloatTensor | None = None,
  865. labels: torch.LongTensor | None = None,
  866. **kwargs: Unpack[TransformersKwargs],
  867. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  868. r"""
  869. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  870. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  871. - 0 corresponds to a *sentence A* token,
  872. - 1 corresponds to a *sentence B* token.
  873. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  874. >= 2. All the value in this tensor should be always < type_vocab_size.
  875. [What are token type IDs?](../glossary#token-type-ids)
  876. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  877. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  878. """
  879. outputs = self.roberta(
  880. input_ids,
  881. attention_mask=attention_mask,
  882. token_type_ids=token_type_ids,
  883. position_ids=position_ids,
  884. inputs_embeds=inputs_embeds,
  885. return_dict=True,
  886. **kwargs,
  887. )
  888. sequence_output = outputs[0]
  889. sequence_output = self.dropout(sequence_output)
  890. logits = self.classifier(sequence_output)
  891. loss = None
  892. if labels is not None:
  893. # move labels to correct device
  894. labels = labels.to(logits.device)
  895. loss_fct = CrossEntropyLoss()
  896. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  897. return TokenClassifierOutput(
  898. loss=loss,
  899. logits=logits,
  900. hidden_states=outputs.hidden_states,
  901. attentions=outputs.attentions,
  902. )
  903. @auto_docstring
  904. class CamembertForQuestionAnswering(CamembertPreTrainedModel):
  905. def __init__(self, config):
  906. super().__init__(config)
  907. self.num_labels = config.num_labels
  908. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  909. self.roberta = CamembertModel(config, add_pooling_layer=False)
  910. # Initialize weights and apply final processing
  911. self.post_init()
  912. @can_return_tuple
  913. @auto_docstring
  914. def forward(
  915. self,
  916. input_ids: torch.LongTensor | None = None,
  917. attention_mask: torch.FloatTensor | None = None,
  918. token_type_ids: torch.LongTensor | None = None,
  919. position_ids: torch.LongTensor | None = None,
  920. inputs_embeds: torch.FloatTensor | None = None,
  921. start_positions: torch.LongTensor | None = None,
  922. end_positions: torch.LongTensor | None = None,
  923. **kwargs: Unpack[TransformersKwargs],
  924. ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
  925. r"""
  926. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  927. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  928. - 0 corresponds to a *sentence A* token,
  929. - 1 corresponds to a *sentence B* token.
  930. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  931. >= 2. All the value in this tensor should be always < type_vocab_size.
  932. [What are token type IDs?](../glossary#token-type-ids)
  933. """
  934. outputs = self.roberta(
  935. input_ids,
  936. attention_mask=attention_mask,
  937. token_type_ids=token_type_ids,
  938. position_ids=position_ids,
  939. inputs_embeds=inputs_embeds,
  940. return_dict=True,
  941. **kwargs,
  942. )
  943. sequence_output = outputs[0]
  944. logits = self.qa_outputs(sequence_output)
  945. start_logits, end_logits = logits.split(1, dim=-1)
  946. start_logits = start_logits.squeeze(-1).contiguous()
  947. end_logits = end_logits.squeeze(-1).contiguous()
  948. total_loss = None
  949. if start_positions is not None and end_positions is not None:
  950. # If we are on multi-GPU, split add a dimension
  951. if len(start_positions.size()) > 1:
  952. start_positions = start_positions.squeeze(-1)
  953. if len(end_positions.size()) > 1:
  954. end_positions = end_positions.squeeze(-1)
  955. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  956. ignored_index = start_logits.size(1)
  957. start_positions = start_positions.clamp(0, ignored_index)
  958. end_positions = end_positions.clamp(0, ignored_index)
  959. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  960. start_loss = loss_fct(start_logits, start_positions)
  961. end_loss = loss_fct(end_logits, end_positions)
  962. total_loss = (start_loss + end_loss) / 2
  963. return QuestionAnsweringModelOutput(
  964. loss=total_loss,
  965. start_logits=start_logits,
  966. end_logits=end_logits,
  967. hidden_states=outputs.hidden_states,
  968. attentions=outputs.attentions,
  969. )
  970. @auto_docstring(
  971. custom_intro="""
  972. Camembert Model with a `language modeling` head on top for CLM fine-tuning.
  973. """
  974. )
  975. class CamembertForCausalLM(CamembertPreTrainedModel, GenerationMixin):
  976. _tied_weights_keys = {
  977. "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
  978. "lm_head.decoder.bias": "lm_head.bias",
  979. }
  980. def __init__(self, config):
  981. super().__init__(config)
  982. if not config.is_decoder:
  983. logger.warning("If you want to use `CamembertLMHeadModel` as a standalone, add `is_decoder=True.`")
  984. self.lm_head = CamembertLMHead(config)
  985. self.roberta = CamembertModel(config, add_pooling_layer=False)
  986. # Initialize weights and apply final processing
  987. self.post_init()
  988. def get_output_embeddings(self):
  989. return self.lm_head.decoder
  990. def set_output_embeddings(self, new_embeddings):
  991. self.lm_head.decoder = new_embeddings
  992. @can_return_tuple
  993. @auto_docstring
  994. def forward(
  995. self,
  996. input_ids: torch.LongTensor | None = None,
  997. attention_mask: torch.FloatTensor | None = None,
  998. token_type_ids: torch.LongTensor | None = None,
  999. position_ids: torch.LongTensor | None = None,
  1000. inputs_embeds: torch.FloatTensor | None = None,
  1001. encoder_hidden_states: torch.FloatTensor | None = None,
  1002. encoder_attention_mask: torch.FloatTensor | None = None,
  1003. labels: torch.LongTensor | None = None,
  1004. past_key_values: tuple[tuple[torch.FloatTensor]] | None = None,
  1005. use_cache: bool | None = None,
  1006. logits_to_keep: int | torch.Tensor = 0,
  1007. **kwargs: Unpack[TransformersKwargs],
  1008. ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
  1009. r"""
  1010. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1011. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  1012. - 0 corresponds to a *sentence A* token,
  1013. - 1 corresponds to a *sentence B* token.
  1014. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  1015. >= 2. All the value in this tensor should be always < type_vocab_size.
  1016. [What are token type IDs?](../glossary#token-type-ids)
  1017. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1018. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  1019. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  1020. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  1021. Example:
  1022. ```python
  1023. >>> from transformers import AutoTokenizer, CamembertForCausalLM, AutoConfig
  1024. >>> import torch
  1025. >>> tokenizer = AutoTokenizer.from_pretrained("almanach/camembert-base")
  1026. >>> config = AutoConfig.from_pretrained("almanach/camembert-base")
  1027. >>> config.is_decoder = True
  1028. >>> model = CamembertForCausalLM.from_pretrained("almanach/camembert-base", config=config)
  1029. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1030. >>> outputs = model(**inputs)
  1031. >>> prediction_logits = outputs.logits
  1032. ```"""
  1033. if labels is not None:
  1034. use_cache = False
  1035. outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.roberta(
  1036. input_ids,
  1037. attention_mask=attention_mask,
  1038. token_type_ids=token_type_ids,
  1039. position_ids=position_ids,
  1040. inputs_embeds=inputs_embeds,
  1041. encoder_hidden_states=encoder_hidden_states,
  1042. encoder_attention_mask=encoder_attention_mask,
  1043. past_key_values=past_key_values,
  1044. use_cache=use_cache,
  1045. return_dict=True,
  1046. **kwargs,
  1047. )
  1048. hidden_states = outputs.last_hidden_state
  1049. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1050. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1051. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1052. loss = None
  1053. if labels is not None:
  1054. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  1055. return CausalLMOutputWithCrossAttentions(
  1056. loss=loss,
  1057. logits=logits,
  1058. past_key_values=outputs.past_key_values,
  1059. hidden_states=outputs.hidden_states,
  1060. attentions=outputs.attentions,
  1061. cross_attentions=outputs.cross_attentions,
  1062. )
# Public names of this module: `__all__` controls what `from <module> import *` exports.
__all__ = [
    "CamembertForCausalLM",
    "CamembertForMaskedLM",
    "CamembertForMultipleChoice",
    "CamembertForQuestionAnswering",
    "CamembertForSequenceClassification",
    "CamembertForTokenClassification",
    "CamembertModel",
    "CamembertPreTrainedModel",
]