modeling_xmod.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390
  1. # Copyright 2023 Meta AI Team and the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch X-MOD model."""
  15. from collections.abc import Callable
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ... import initialization as init
  20. from ...activations import ACT2FN, gelu
  21. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  22. from ...generation import GenerationMixin
  23. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutputWithPastAndCrossAttentions,
  27. BaseModelOutputWithPoolingAndCrossAttentions,
  28. CausalLMOutputWithCrossAttentions,
  29. MaskedLMOutput,
  30. MultipleChoiceModelOutput,
  31. QuestionAnsweringModelOutput,
  32. SequenceClassifierOutput,
  33. TokenClassifierOutput,
  34. )
  35. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  36. from ...processing_utils import Unpack
  37. from ...pytorch_utils import apply_chunking_to_forward
  38. from ...utils import TransformersKwargs, auto_docstring, logging
  39. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  40. from ...utils.output_capturing import capture_outputs
  41. from .configuration_xmod import XmodConfig
# Module-level logger named after this module's import path (`__name__`); used e.g. by
# `XmodPreTrainedModel.freeze_embeddings_and_language_adapters` to report what it freezes.
logger = logging.get_logger(__name__)
  43. # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Xmod
  44. class XmodEmbeddings(nn.Module):
  45. """Construct the embeddings from word, position and token_type embeddings."""
  46. def __init__(self, config):
  47. super().__init__()
  48. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  49. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  50. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  51. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  52. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  53. self.register_buffer(
  54. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  55. )
  56. self.register_buffer(
  57. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  58. )
  59. self.padding_idx = config.pad_token_id
  60. self.position_embeddings = nn.Embedding(
  61. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  62. )
  63. def forward(
  64. self,
  65. input_ids: torch.LongTensor | None = None,
  66. token_type_ids: torch.LongTensor | None = None,
  67. position_ids: torch.LongTensor | None = None,
  68. inputs_embeds: torch.FloatTensor | None = None,
  69. past_key_values_length: int = 0,
  70. ) -> torch.Tensor:
  71. if position_ids is None:
  72. if input_ids is not None:
  73. # Create the position ids from the input token ids. Any padded tokens remain padded.
  74. position_ids = self.create_position_ids_from_input_ids(
  75. input_ids, self.padding_idx, past_key_values_length
  76. )
  77. else:
  78. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, self.padding_idx)
  79. if input_ids is not None:
  80. input_shape = input_ids.size()
  81. else:
  82. input_shape = inputs_embeds.size()[:-1]
  83. batch_size, seq_length = input_shape
  84. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  85. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  86. # issue #5664
  87. if token_type_ids is None:
  88. if hasattr(self, "token_type_ids"):
  89. # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
  90. buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
  91. buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
  92. token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
  93. else:
  94. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  95. if inputs_embeds is None:
  96. inputs_embeds = self.word_embeddings(input_ids)
  97. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  98. embeddings = inputs_embeds + token_type_embeddings
  99. position_embeddings = self.position_embeddings(position_ids)
  100. embeddings = embeddings + position_embeddings
  101. embeddings = self.LayerNorm(embeddings)
  102. embeddings = self.dropout(embeddings)
  103. return embeddings
  104. @staticmethod
  105. def create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx):
  106. """
  107. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  108. Args:
  109. inputs_embeds: torch.Tensor
  110. Returns: torch.Tensor
  111. """
  112. input_shape = inputs_embeds.size()[:-1]
  113. sequence_length = input_shape[1]
  114. position_ids = torch.arange(
  115. padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  116. )
  117. return position_ids.unsqueeze(0).expand(input_shape)
  118. @staticmethod
  119. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  120. """
  121. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  122. are ignored. This is modified from fairseq's `utils.make_positions`.
  123. Args:
  124. x: torch.Tensor x:
  125. Returns: torch.Tensor
  126. """
  127. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  128. mask = input_ids.ne(padding_idx).int()
  129. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  130. return incremental_indices.long() + padding_idx
  131. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  132. def eager_attention_forward(
  133. module: nn.Module,
  134. query: torch.Tensor,
  135. key: torch.Tensor,
  136. value: torch.Tensor,
  137. attention_mask: torch.Tensor | None,
  138. scaling: float | None = None,
  139. dropout: float = 0.0,
  140. **kwargs: Unpack[TransformersKwargs],
  141. ):
  142. if scaling is None:
  143. scaling = query.size(-1) ** -0.5
  144. # Take the dot product between "query" and "key" to get the raw attention scores.
  145. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  146. if attention_mask is not None:
  147. attn_weights = attn_weights + attention_mask
  148. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  149. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  150. attn_output = torch.matmul(attn_weights, value)
  151. attn_output = attn_output.transpose(1, 2).contiguous()
  152. return attn_output, attn_weights
  153. # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Xmod
  154. class XmodSelfAttention(nn.Module):
  155. def __init__(self, config, is_causal=False, layer_idx=None):
  156. super().__init__()
  157. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  158. raise ValueError(
  159. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  160. f"heads ({config.num_attention_heads})"
  161. )
  162. self.config = config
  163. self.num_attention_heads = config.num_attention_heads
  164. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  165. self.all_head_size = self.num_attention_heads * self.attention_head_size
  166. self.scaling = self.attention_head_size**-0.5
  167. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  168. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  169. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  170. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  171. self.is_decoder = config.is_decoder
  172. self.is_causal = is_causal
  173. self.layer_idx = layer_idx
  174. def forward(
  175. self,
  176. hidden_states: torch.Tensor,
  177. attention_mask: torch.FloatTensor | None = None,
  178. past_key_values: Cache | None = None,
  179. **kwargs: Unpack[TransformersKwargs],
  180. ) -> tuple[torch.Tensor]:
  181. input_shape = hidden_states.shape[:-1]
  182. hidden_shape = (*input_shape, -1, self.attention_head_size)
  183. # get all proj
  184. query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
  185. key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
  186. value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
  187. if past_key_values is not None:
  188. # decoder-only roberta can have a simple dynamic cache for example
  189. current_past_key_values = past_key_values
  190. if isinstance(past_key_values, EncoderDecoderCache):
  191. current_past_key_values = past_key_values.self_attention_cache
  192. # save all key/value_layer to cache to be re-used for fast auto-regressive generation
  193. key_layer, value_layer = current_past_key_values.update(key_layer, value_layer, self.layer_idx)
  194. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  195. self.config._attn_implementation, eager_attention_forward
  196. )
  197. attn_output, attn_weights = attention_interface(
  198. self,
  199. query_layer,
  200. key_layer,
  201. value_layer,
  202. attention_mask,
  203. dropout=0.0 if not self.training else self.dropout.p,
  204. scaling=self.scaling,
  205. **kwargs,
  206. )
  207. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  208. return attn_output, attn_weights
  209. # Copied from transformers.models.bert.modeling_bert.BertCrossAttention with Bert->Xmod
  210. class XmodCrossAttention(nn.Module):
  211. def __init__(self, config, is_causal=False, layer_idx=None):
  212. super().__init__()
  213. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  214. raise ValueError(
  215. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  216. f"heads ({config.num_attention_heads})"
  217. )
  218. self.config = config
  219. self.num_attention_heads = config.num_attention_heads
  220. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  221. self.all_head_size = self.num_attention_heads * self.attention_head_size
  222. self.scaling = self.attention_head_size**-0.5
  223. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  224. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  225. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  226. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  227. self.is_causal = is_causal
  228. self.layer_idx = layer_idx
  229. def forward(
  230. self,
  231. hidden_states: torch.Tensor,
  232. encoder_hidden_states: torch.FloatTensor | None = None,
  233. attention_mask: torch.FloatTensor | None = None,
  234. past_key_values: EncoderDecoderCache | None = None,
  235. **kwargs: Unpack[TransformersKwargs],
  236. ) -> tuple[torch.Tensor]:
  237. # determine input shapes
  238. input_shape = hidden_states.shape[:-1]
  239. hidden_shape = (*input_shape, -1, self.attention_head_size)
  240. # get query proj
  241. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  242. is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
  243. if past_key_values is not None and is_updated:
  244. # reuse k,v, cross_attentions
  245. key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
  246. value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
  247. else:
  248. kv_shape = (*encoder_hidden_states.shape[:-1], -1, self.attention_head_size)
  249. key_layer = self.key(encoder_hidden_states).view(kv_shape).transpose(1, 2)
  250. value_layer = self.value(encoder_hidden_states).view(kv_shape).transpose(1, 2)
  251. if past_key_values is not None:
  252. # save all states to the cache
  253. key_layer, value_layer = past_key_values.cross_attention_cache.update(
  254. key_layer, value_layer, self.layer_idx
  255. )
  256. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  257. past_key_values.is_updated[self.layer_idx] = True
  258. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  259. self.config._attn_implementation, eager_attention_forward
  260. )
  261. attn_output, attn_weights = attention_interface(
  262. self,
  263. query_layer,
  264. key_layer,
  265. value_layer,
  266. attention_mask,
  267. dropout=0.0 if not self.training else self.dropout.p,
  268. scaling=self.scaling,
  269. **kwargs,
  270. )
  271. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  272. return attn_output, attn_weights
  273. class XmodSelfOutput(nn.Module):
  274. # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput.__init__
  275. def __init__(self, config):
  276. super().__init__()
  277. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  278. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  279. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  280. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  281. hidden_states = self.dense(hidden_states)
  282. hidden_states = self.dropout(hidden_states)
  283. hidden_states = hidden_states + input_tensor
  284. return hidden_states
  285. class XmodAttention(nn.Module):
  286. def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
  287. super().__init__()
  288. self.is_cross_attention = is_cross_attention
  289. attention_class = XmodCrossAttention if is_cross_attention else XmodSelfAttention
  290. self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
  291. self.output = XmodSelfOutput(config)
  292. self.pre_norm = config.pre_norm
  293. def forward(
  294. self,
  295. hidden_states: torch.Tensor,
  296. attention_mask: torch.FloatTensor | None = None,
  297. encoder_hidden_states: torch.FloatTensor | None = None,
  298. encoder_attention_mask: torch.FloatTensor | None = None,
  299. past_key_values: tuple[tuple[torch.FloatTensor]] | None = None,
  300. **kwargs: Unpack[TransformersKwargs],
  301. ) -> tuple[torch.Tensor]:
  302. residual = hidden_states
  303. if self.pre_norm:
  304. hidden_states = self.output.LayerNorm(hidden_states)
  305. attention_mask = attention_mask if not self.is_cross_attention else encoder_attention_mask
  306. attention_output, attn_weights = self.self(
  307. hidden_states,
  308. encoder_hidden_states=encoder_hidden_states,
  309. attention_mask=attention_mask,
  310. past_key_values=past_key_values,
  311. **kwargs,
  312. )
  313. attention_output = self.output(attention_output, residual)
  314. if not self.pre_norm:
  315. attention_output = self.output.LayerNorm(attention_output)
  316. return attention_output, attn_weights
  317. # Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate
  318. class XmodIntermediate(nn.Module):
  319. def __init__(self, config):
  320. super().__init__()
  321. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  322. if isinstance(config.hidden_act, str):
  323. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  324. else:
  325. self.intermediate_act_fn = config.hidden_act
  326. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  327. hidden_states = self.dense(hidden_states)
  328. hidden_states = self.intermediate_act_fn(hidden_states)
  329. return hidden_states
  330. class XmodAdapter(nn.Module):
  331. def __init__(self, config):
  332. super().__init__()
  333. self.bottleneck_size = config.hidden_size // config.adapter_reduction_factor
  334. self.dense1 = nn.Linear(config.hidden_size, self.bottleneck_size)
  335. self.dense2 = nn.Linear(self.bottleneck_size, config.hidden_size)
  336. if isinstance(config.hidden_act, str):
  337. self.adapter_act_fn = ACT2FN[config.hidden_act]
  338. else:
  339. self.adapter_act_fn = config.hidden_act
  340. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  341. hidden_states = self.dense1(hidden_states)
  342. hidden_states = self.adapter_act_fn(hidden_states)
  343. hidden_states = self.dense2(hidden_states)
  344. return hidden_states
  345. class XmodOutput(nn.Module):
  346. def __init__(self, config):
  347. super().__init__()
  348. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  349. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  350. self.ln_before_adapter = config.ln_before_adapter
  351. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  352. if config.adapter_layer_norm:
  353. self.adapter_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  354. else:
  355. self.adapter_layer_norm = None
  356. self.adapter_reuse_layer_norm = config.adapter_reuse_layer_norm
  357. self.adapter_modules = nn.ModuleDict({})
  358. for language in config.languages:
  359. self.adapter_modules[str(language)] = XmodAdapter(config)
  360. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, lang_ids: torch.Tensor) -> torch.Tensor:
  361. hidden_states = self.dense(hidden_states)
  362. hidden_states = self.dropout(hidden_states)
  363. hidden_states = hidden_states + input_tensor
  364. hidden_states = self.lang_adapter(lang_ids, hidden_states)
  365. return hidden_states
  366. def lang_adapter(self, lang_ids: torch.Tensor, hidden_states: torch.Tensor):
  367. if not self.ln_before_adapter:
  368. residual = hidden_states
  369. if self.adapter_layer_norm is not None:
  370. hidden_states = self.adapter_layer_norm(hidden_states)
  371. elif self.adapter_reuse_layer_norm:
  372. hidden_states = self.LayerNorm(hidden_states)
  373. if self.ln_before_adapter:
  374. residual = hidden_states
  375. new_hidden_states = torch.zeros_like(hidden_states)
  376. for adapter_idx, lang_key in enumerate(self.adapter_modules.keys()):
  377. lang_mask = lang_ids == adapter_idx
  378. lang_hidden_states = hidden_states[lang_mask]
  379. adapted_lang_hidden_states = self.adapter_modules[lang_key](lang_hidden_states)
  380. new_hidden_states[lang_mask] = adapted_lang_hidden_states
  381. hidden_states = self.dropout(new_hidden_states)
  382. hidden_states += residual
  383. return hidden_states
  384. class XmodLayer(GradientCheckpointingLayer):
  385. def __init__(self, config, layer_idx=None):
  386. super().__init__()
  387. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  388. self.seq_len_dim = 1
  389. self.attention = XmodAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
  390. self.is_decoder = config.is_decoder
  391. self.add_cross_attention = config.add_cross_attention
  392. if self.add_cross_attention:
  393. if not self.is_decoder:
  394. raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
  395. self.crossattention = XmodAttention(
  396. config,
  397. is_causal=False,
  398. layer_idx=layer_idx,
  399. is_cross_attention=True,
  400. )
  401. self.intermediate = XmodIntermediate(config)
  402. self.output = XmodOutput(config)
  403. self.pre_norm = config.pre_norm
  404. def forward(
  405. self,
  406. hidden_states: torch.Tensor,
  407. lang_ids: torch.Tensor,
  408. attention_mask: torch.FloatTensor | None = None,
  409. encoder_hidden_states: torch.FloatTensor | None = None,
  410. encoder_attention_mask: torch.FloatTensor | None = None,
  411. past_key_values: tuple[tuple[torch.FloatTensor]] | None = None,
  412. **kwargs: Unpack[TransformersKwargs],
  413. ) -> torch.Tensor:
  414. self_attention_output, _ = self.attention(
  415. hidden_states,
  416. attention_mask,
  417. past_key_values=past_key_values,
  418. **kwargs,
  419. )
  420. attention_output = self_attention_output
  421. if self.is_decoder and encoder_hidden_states is not None:
  422. if not hasattr(self, "crossattention"):
  423. raise ValueError(
  424. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  425. " by setting `config.add_cross_attention=True`"
  426. )
  427. cross_attention_output, _ = self.crossattention(
  428. attention_output,
  429. None, # attention_mask
  430. encoder_hidden_states,
  431. encoder_attention_mask,
  432. past_key_values=past_key_values,
  433. **kwargs,
  434. )
  435. attention_output = cross_attention_output
  436. residual = attention_output
  437. if self.pre_norm:
  438. attention_output = self.output.LayerNorm(attention_output)
  439. intermediate_output = apply_chunking_to_forward(
  440. self.feed_forward_chunk,
  441. self.chunk_size_feed_forward,
  442. self.seq_len_dim,
  443. attention_output,
  444. )
  445. layer_output = self.output(intermediate_output, residual, lang_ids)
  446. if not self.pre_norm:
  447. layer_output = self.output.LayerNorm(layer_output)
  448. return layer_output
  449. def feed_forward_chunk(self, attention_output):
  450. return self.intermediate(attention_output)
  451. class XmodEncoder(nn.Module):
  452. def __init__(self, config):
  453. super().__init__()
  454. self.config = config
  455. self.layer = nn.ModuleList([XmodLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  456. self.is_pre_norm = config.pre_norm
  457. if self.is_pre_norm:
  458. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  459. def forward(
  460. self,
  461. hidden_states: torch.Tensor,
  462. lang_ids: torch.Tensor,
  463. attention_mask: torch.FloatTensor | None = None,
  464. encoder_hidden_states: torch.FloatTensor | None = None,
  465. encoder_attention_mask: torch.FloatTensor | None = None,
  466. past_key_values: tuple[tuple[torch.FloatTensor]] | None = None,
  467. use_cache: bool | None = None,
  468. **kwargs: Unpack[TransformersKwargs],
  469. ) -> tuple[torch.Tensor] | BaseModelOutputWithPastAndCrossAttentions:
  470. for i, layer_module in enumerate(self.layer):
  471. hidden_states = layer_module(
  472. hidden_states,
  473. lang_ids,
  474. attention_mask,
  475. encoder_hidden_states,
  476. encoder_attention_mask,
  477. past_key_values,
  478. **kwargs,
  479. )
  480. if self.is_pre_norm:
  481. hidden_states = self.LayerNorm(hidden_states)
  482. return BaseModelOutputWithPastAndCrossAttentions(
  483. last_hidden_state=hidden_states,
  484. past_key_values=past_key_values if use_cache else None,
  485. )
  486. # Copied from transformers.models.roberta.modeling_roberta.RobertaPooler
  487. class XmodPooler(nn.Module):
  488. def __init__(self, config):
  489. super().__init__()
  490. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  491. self.activation = nn.Tanh()
  492. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  493. # We "pool" the model by simply taking the hidden state corresponding
  494. # to the first token.
  495. first_token_tensor = hidden_states[:, 0]
  496. pooled_output = self.dense(first_token_tensor)
  497. pooled_output = self.activation(pooled_output)
  498. return pooled_output
  499. @auto_docstring
  500. class XmodPreTrainedModel(PreTrainedModel):
  501. config_class = XmodConfig
  502. base_model_prefix = "roberta"
  503. supports_gradient_checkpointing = True
  504. no_split_modules = ["XmodEmbeddings", "XmodSelfAttention", "XmodCrossAttention"]
  505. _supports_flash_attn = True
  506. _supports_sdpa = True
  507. _supports_flex_attn = True
  508. _supports_attention_backend = True
  509. _can_record_outputs = {
  510. "hidden_states": XmodLayer,
  511. "attentions": XmodSelfAttention,
  512. "cross_attentions": XmodCrossAttention,
  513. }
  514. @torch.no_grad()
  515. def _init_weights(self, module):
  516. """Initialize the weights"""
  517. super()._init_weights(module)
  518. if isinstance(module, XmodLMHead):
  519. init.zeros_(module.bias)
  520. elif isinstance(module, XmodEmbeddings):
  521. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  522. init.zeros_(module.token_type_ids)
  523. def set_default_language(self, language: str):
  524. """
  525. Set the default language code for the model. This is used when the language is not specified in the input.
  526. Args:
  527. language (`str`): The language code, such as `"en_XX"` or `"de_DE"`.
  528. """
  529. if language not in self.config.languages:
  530. raise ValueError(
  531. f"{self} does not have an adapter for {language}. Supported languages: {list(self.config.languages)}"
  532. )
  533. self.config.default_language = language
  534. def freeze_embeddings_and_language_adapters(self):
  535. """
  536. Freeze the embeddings and language adapters of the model. Usually, this is applied before the model is
  537. fine-tuned on a downstream task.
  538. """
  539. logger.info("Freezing embeddings")
  540. for parameter in self.roberta.embeddings.parameters():
  541. parameter.requires_grad = False
  542. logger.info("Freezing adapters")
  543. for layer in self.roberta.encoder.layer:
  544. if layer.output.adapter_layer_norm is not None:
  545. for parameter in layer.output.adapter_layer_norm.parameters():
  546. parameter.requires_grad = False
  547. for parameter in layer.output.adapter_modules.parameters():
  548. parameter.requires_grad = False
  549. @auto_docstring(
  550. custom_intro="""
  551. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  552. cross-attention is added between the self-attention layers, following the architecture described in *Attention is
  553. all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
  554. Kaiser and Illia Polosukhin.
  555. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  556. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  557. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  558. .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
  559. """
  560. )
  561. class XmodModel(XmodPreTrainedModel):
  562. def __init__(self, config, add_pooling_layer=True):
  563. r"""
  564. add_pooling_layer (bool, *optional*, defaults to `True`):
  565. Whether to add a pooling layer
  566. """
  567. super().__init__(config)
  568. self.config = config
  569. self.gradient_checkpointing = False
  570. self.embeddings = XmodEmbeddings(config)
  571. self.encoder = XmodEncoder(config)
  572. self.pooler = XmodPooler(config) if add_pooling_layer else None
  573. # Initialize weights and apply final processing
  574. self.post_init()
  575. # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.get_input_embeddings
  576. def get_input_embeddings(self):
  577. return self.embeddings.word_embeddings
  578. # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.set_input_embeddings
  579. def set_input_embeddings(self, value):
  580. self.embeddings.word_embeddings = value
  581. @merge_with_config_defaults
  582. @capture_outputs
  583. @auto_docstring
  584. def forward(
  585. self,
  586. input_ids: torch.Tensor | None = None,
  587. lang_ids: torch.LongTensor | None = None,
  588. attention_mask: torch.Tensor | None = None,
  589. token_type_ids: torch.Tensor | None = None,
  590. position_ids: torch.Tensor | None = None,
  591. inputs_embeds: torch.Tensor | None = None,
  592. encoder_hidden_states: torch.Tensor | None = None,
  593. encoder_attention_mask: torch.Tensor | None = None,
  594. past_key_values: list[torch.FloatTensor] | None = None,
  595. use_cache: bool | None = None,
  596. **kwargs: Unpack[TransformersKwargs],
  597. ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
  598. r"""
  599. lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  600. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  601. that corresponds to `self.config.default_language`.
  602. """
  603. if (input_ids is None) ^ (inputs_embeds is not None):
  604. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  605. if self.config.is_decoder:
  606. use_cache = use_cache if use_cache is not None else self.config.use_cache
  607. else:
  608. use_cache = False
  609. if use_cache and past_key_values is None:
  610. past_key_values = (
  611. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  612. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  613. else DynamicCache(config=self.config)
  614. )
  615. batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
  616. device = input_ids.device if input_ids is not None else inputs_embeds.device
  617. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  618. if lang_ids is None:
  619. if self.config.default_language is None:
  620. raise ValueError("Input language unknown. Please call `XmodPreTrainedModel.set_default_language()`")
  621. adapter_languages = list(self.encoder.layer[0].output.adapter_modules.keys())
  622. default_lang_id = adapter_languages.index(self.config.default_language)
  623. lang_ids = default_lang_id * torch.ones(batch_size, device=device)
  624. embedding_output = self.embeddings(
  625. input_ids=input_ids,
  626. position_ids=position_ids,
  627. token_type_ids=token_type_ids,
  628. inputs_embeds=inputs_embeds,
  629. past_key_values_length=past_key_values_length,
  630. )
  631. attention_mask, encoder_attention_mask = self._create_attention_masks(
  632. attention_mask=attention_mask,
  633. encoder_attention_mask=encoder_attention_mask,
  634. embedding_output=embedding_output,
  635. encoder_hidden_states=encoder_hidden_states,
  636. past_key_values=past_key_values,
  637. )
  638. encoder_outputs = self.encoder(
  639. embedding_output,
  640. lang_ids=lang_ids,
  641. attention_mask=attention_mask,
  642. encoder_hidden_states=encoder_hidden_states,
  643. encoder_attention_mask=encoder_attention_mask,
  644. past_key_values=past_key_values,
  645. use_cache=use_cache,
  646. position_ids=position_ids,
  647. **kwargs,
  648. )
  649. sequence_output = encoder_outputs[0]
  650. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  651. return BaseModelOutputWithPoolingAndCrossAttentions(
  652. last_hidden_state=sequence_output,
  653. pooler_output=pooled_output,
  654. past_key_values=encoder_outputs.past_key_values,
  655. )
  656. # Copied from transformers.models.bert.modeling_bert.BertModel._create_attention_masks
  657. def _create_attention_masks(
  658. self,
  659. attention_mask,
  660. encoder_attention_mask,
  661. embedding_output,
  662. encoder_hidden_states,
  663. past_key_values,
  664. ):
  665. if self.config.is_decoder:
  666. attention_mask = create_causal_mask(
  667. config=self.config,
  668. inputs_embeds=embedding_output,
  669. attention_mask=attention_mask,
  670. past_key_values=past_key_values,
  671. )
  672. else:
  673. attention_mask = create_bidirectional_mask(
  674. config=self.config,
  675. inputs_embeds=embedding_output,
  676. attention_mask=attention_mask,
  677. )
  678. if encoder_attention_mask is not None:
  679. encoder_attention_mask = create_bidirectional_mask(
  680. config=self.config,
  681. inputs_embeds=embedding_output,
  682. attention_mask=encoder_attention_mask,
  683. encoder_hidden_states=encoder_hidden_states,
  684. )
  685. return attention_mask, encoder_attention_mask
  686. @auto_docstring(
  687. custom_intro="""
  688. X-MOD Model with a `language modeling` head on top for CLM fine-tuning.
  689. """
  690. )
  691. class XmodForCausalLM(XmodPreTrainedModel, GenerationMixin):
  692. _tied_weights_keys = {
  693. "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
  694. "lm_head.decoder.bias": "lm_head.bias",
  695. }
  696. # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.__init__ with Roberta->Xmod
  697. def __init__(self, config):
  698. super().__init__(config)
  699. if not config.is_decoder:
  700. logger.warning("If you want to use `XmodLMHeadModel` as a standalone, add `is_decoder=True.`")
  701. self.roberta = XmodModel(config, add_pooling_layer=False)
  702. self.lm_head = XmodLMHead(config)
  703. # Initialize weights and apply final processing
  704. self.post_init()
  705. # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.get_output_embeddings
  706. def get_output_embeddings(self):
  707. return self.lm_head.decoder
  708. # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.set_output_embeddings
  709. def set_output_embeddings(self, new_embeddings):
  710. self.lm_head.decoder = new_embeddings
  711. @can_return_tuple
  712. @auto_docstring
  713. def forward(
  714. self,
  715. input_ids: torch.LongTensor | None = None,
  716. lang_ids: torch.LongTensor | None = None,
  717. attention_mask: torch.FloatTensor | None = None,
  718. token_type_ids: torch.LongTensor | None = None,
  719. position_ids: torch.LongTensor | None = None,
  720. inputs_embeds: torch.FloatTensor | None = None,
  721. encoder_hidden_states: torch.FloatTensor | None = None,
  722. encoder_attention_mask: torch.FloatTensor | None = None,
  723. labels: torch.LongTensor | None = None,
  724. past_key_values: tuple[tuple[torch.FloatTensor]] | None = None,
  725. use_cache: bool | None = None,
  726. logits_to_keep: int | torch.Tensor = 0,
  727. **kwargs: Unpack[TransformersKwargs],
  728. ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
  729. r"""
  730. lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  731. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  732. that corresponds to `self.config.default_language`.
  733. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  734. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  735. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  736. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  737. Example:
  738. ```python
  739. >>> from transformers import AutoTokenizer, XmodForCausalLM, AutoConfig
  740. >>> import torch
  741. >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
  742. >>> config = AutoConfig.from_pretrained("facebook/xmod-base")
  743. >>> config.is_decoder = True
  744. >>> model = XmodForCausalLM.from_pretrained("facebook/xmod-base", config=config)
  745. >>> model.set_default_language("en_XX")
  746. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  747. >>> outputs = model(**inputs)
  748. >>> prediction_logits = outputs.logits
  749. ```"""
  750. if labels is not None:
  751. use_cache = False
  752. outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.roberta(
  753. input_ids,
  754. lang_ids=lang_ids,
  755. attention_mask=attention_mask,
  756. token_type_ids=token_type_ids,
  757. position_ids=position_ids,
  758. inputs_embeds=inputs_embeds,
  759. encoder_hidden_states=encoder_hidden_states,
  760. encoder_attention_mask=encoder_attention_mask,
  761. past_key_values=past_key_values,
  762. use_cache=use_cache,
  763. return_dict=True,
  764. **kwargs,
  765. )
  766. hidden_states = outputs.last_hidden_state
  767. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  768. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  769. logits = self.lm_head(hidden_states[:, slice_indices, :])
  770. loss = None
  771. if labels is not None:
  772. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  773. return CausalLMOutputWithCrossAttentions(
  774. loss=loss,
  775. logits=logits,
  776. past_key_values=outputs.past_key_values,
  777. hidden_states=outputs.hidden_states,
  778. attentions=outputs.attentions,
  779. cross_attentions=outputs.cross_attentions,
  780. )
  781. @auto_docstring
  782. class XmodForMaskedLM(XmodPreTrainedModel):
  783. _tied_weights_keys = {
  784. "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
  785. "lm_head.decoder.bias": "lm_head.bias",
  786. }
  787. # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.__init__ with Roberta->Xmod
  788. def __init__(self, config):
  789. super().__init__(config)
  790. if config.is_decoder:
  791. logger.warning(
  792. "If you want to use `XmodForMaskedLM` make sure `config.is_decoder=False` for "
  793. "bi-directional self-attention."
  794. )
  795. self.roberta = XmodModel(config, add_pooling_layer=False)
  796. self.lm_head = XmodLMHead(config)
  797. # Initialize weights and apply final processing
  798. self.post_init()
  799. # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.get_output_embeddings
  800. def get_output_embeddings(self):
  801. return self.lm_head.decoder
  802. # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.set_output_embeddings
  803. def set_output_embeddings(self, new_embeddings):
  804. self.lm_head.decoder = new_embeddings
  805. @can_return_tuple
  806. @auto_docstring
  807. def forward(
  808. self,
  809. input_ids: torch.LongTensor | None = None,
  810. lang_ids: torch.LongTensor | None = None,
  811. attention_mask: torch.FloatTensor | None = None,
  812. token_type_ids: torch.LongTensor | None = None,
  813. position_ids: torch.LongTensor | None = None,
  814. inputs_embeds: torch.FloatTensor | None = None,
  815. encoder_hidden_states: torch.FloatTensor | None = None,
  816. encoder_attention_mask: torch.FloatTensor | None = None,
  817. labels: torch.LongTensor | None = None,
  818. **kwargs: Unpack[TransformersKwargs],
  819. ) -> tuple[torch.Tensor] | MaskedLMOutput:
  820. r"""
  821. lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  822. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  823. that corresponds to `self.config.default_language`.
  824. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  825. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  826. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  827. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  828. """
  829. outputs = self.roberta(
  830. input_ids,
  831. lang_ids=lang_ids,
  832. attention_mask=attention_mask,
  833. token_type_ids=token_type_ids,
  834. position_ids=position_ids,
  835. inputs_embeds=inputs_embeds,
  836. encoder_hidden_states=encoder_hidden_states,
  837. encoder_attention_mask=encoder_attention_mask,
  838. return_dict=True,
  839. **kwargs,
  840. )
  841. sequence_output = outputs[0]
  842. prediction_scores = self.lm_head(sequence_output)
  843. masked_lm_loss = None
  844. if labels is not None:
  845. loss_fct = CrossEntropyLoss()
  846. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  847. return MaskedLMOutput(
  848. loss=masked_lm_loss,
  849. logits=prediction_scores,
  850. hidden_states=outputs.hidden_states,
  851. attentions=outputs.attentions,
  852. )
  853. # Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead
  854. class XmodLMHead(nn.Module):
  855. """Roberta Head for masked language modeling."""
  856. def __init__(self, config):
  857. super().__init__()
  858. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  859. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  860. self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
  861. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  862. def forward(self, features, **kwargs):
  863. x = self.dense(features)
  864. x = gelu(x)
  865. x = self.layer_norm(x)
  866. # project back to size of vocabulary with bias
  867. x = self.decoder(x)
  868. return x
  869. @auto_docstring(
  870. custom_intro="""
  871. X-MOD Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  872. output) e.g. for GLUE tasks.
  873. """
  874. )
  875. class XmodForSequenceClassification(XmodPreTrainedModel):
  876. # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification.__init__ with Roberta->Xmod
  877. def __init__(self, config):
  878. super().__init__(config)
  879. self.num_labels = config.num_labels
  880. self.config = config
  881. self.roberta = XmodModel(config, add_pooling_layer=False)
  882. self.classifier = XmodClassificationHead(config)
  883. # Initialize weights and apply final processing
  884. self.post_init()
  885. @can_return_tuple
  886. @auto_docstring
  887. def forward(
  888. self,
  889. input_ids: torch.LongTensor | None = None,
  890. lang_ids: torch.LongTensor | None = None,
  891. attention_mask: torch.FloatTensor | None = None,
  892. token_type_ids: torch.LongTensor | None = None,
  893. position_ids: torch.LongTensor | None = None,
  894. inputs_embeds: torch.FloatTensor | None = None,
  895. labels: torch.LongTensor | None = None,
  896. **kwargs: Unpack[TransformersKwargs],
  897. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  898. r"""
  899. lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  900. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  901. that corresponds to `self.config.default_language`.
  902. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  903. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  904. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  905. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  906. """
  907. outputs = self.roberta(
  908. input_ids,
  909. lang_ids=lang_ids,
  910. attention_mask=attention_mask,
  911. token_type_ids=token_type_ids,
  912. position_ids=position_ids,
  913. inputs_embeds=inputs_embeds,
  914. return_dict=True,
  915. **kwargs,
  916. )
  917. sequence_output = outputs[0]
  918. logits = self.classifier(sequence_output)
  919. loss = None
  920. if labels is not None:
  921. if self.config.problem_type is None:
  922. if self.num_labels == 1:
  923. self.config.problem_type = "regression"
  924. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  925. self.config.problem_type = "single_label_classification"
  926. else:
  927. self.config.problem_type = "multi_label_classification"
  928. if self.config.problem_type == "regression":
  929. loss_fct = MSELoss()
  930. if self.num_labels == 1:
  931. loss = loss_fct(logits.squeeze(), labels.squeeze())
  932. else:
  933. loss = loss_fct(logits, labels)
  934. elif self.config.problem_type == "single_label_classification":
  935. loss_fct = CrossEntropyLoss()
  936. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  937. elif self.config.problem_type == "multi_label_classification":
  938. loss_fct = BCEWithLogitsLoss()
  939. loss = loss_fct(logits, labels)
  940. return SequenceClassifierOutput(
  941. loss=loss,
  942. logits=logits,
  943. hidden_states=outputs.hidden_states,
  944. attentions=outputs.attentions,
  945. )
  946. @auto_docstring
  947. class XmodForMultipleChoice(XmodPreTrainedModel):
  948. # Copied from transformers.models.roberta.modeling_roberta.RobertaForMultipleChoice.__init__ with Roberta->Xmod
  949. def __init__(self, config):
  950. super().__init__(config)
  951. self.roberta = XmodModel(config)
  952. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  953. self.classifier = nn.Linear(config.hidden_size, 1)
  954. # Initialize weights and apply final processing
  955. self.post_init()
  956. @can_return_tuple
  957. @auto_docstring
  958. def forward(
  959. self,
  960. input_ids: torch.LongTensor | None = None,
  961. lang_ids: torch.LongTensor | None = None,
  962. token_type_ids: torch.LongTensor | None = None,
  963. attention_mask: torch.FloatTensor | None = None,
  964. labels: torch.LongTensor | None = None,
  965. position_ids: torch.LongTensor | None = None,
  966. inputs_embeds: torch.FloatTensor | None = None,
  967. **kwargs: Unpack[TransformersKwargs],
  968. ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
  969. r"""
  970. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  971. Indices of input sequence tokens in the vocabulary.
  972. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  973. [`PreTrainedTokenizer.__call__`] for details.
  974. [What are input IDs?](../glossary#input-ids)
  975. lang_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  976. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  977. that corresponds to `self.config.default_language`.
  978. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  979. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  980. 1]`:
  981. - 0 corresponds to a *sentence A* token,
  982. - 1 corresponds to a *sentence B* token.
  983. [What are token type IDs?](../glossary#token-type-ids)
  984. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  985. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  986. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  987. `input_ids` above)
  988. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  989. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  990. config.max_position_embeddings - 1]`.
  991. [What are position IDs?](../glossary#position-ids)
  992. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  993. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  994. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  995. model's internal embedding lookup matrix.
  996. """
  997. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  998. flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  999. flat_lang_ids = lang_ids.repeat(input_ids.size(0) * input_ids.size(1)) if lang_ids is not None else None
  1000. flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1001. flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1002. flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1003. flat_inputs_embeds = (
  1004. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1005. if inputs_embeds is not None
  1006. else None
  1007. )
  1008. outputs = self.roberta(
  1009. flat_input_ids,
  1010. lang_ids=flat_lang_ids,
  1011. position_ids=flat_position_ids,
  1012. token_type_ids=flat_token_type_ids,
  1013. attention_mask=flat_attention_mask,
  1014. inputs_embeds=flat_inputs_embeds,
  1015. return_dict=True,
  1016. **kwargs,
  1017. )
  1018. pooled_output = outputs[1]
  1019. pooled_output = self.dropout(pooled_output)
  1020. logits = self.classifier(pooled_output)
  1021. reshaped_logits = logits.view(-1, num_choices)
  1022. loss = None
  1023. if labels is not None:
  1024. loss_fct = CrossEntropyLoss()
  1025. loss = loss_fct(reshaped_logits, labels)
  1026. return MultipleChoiceModelOutput(
  1027. loss=loss,
  1028. logits=reshaped_logits,
  1029. hidden_states=outputs.hidden_states,
  1030. attentions=outputs.attentions,
  1031. )
  1032. @auto_docstring
  1033. class XmodForTokenClassification(XmodPreTrainedModel):
  1034. # Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification.__init__ with Roberta->Xmod
  1035. def __init__(self, config):
  1036. super().__init__(config)
  1037. self.num_labels = config.num_labels
  1038. self.roberta = XmodModel(config, add_pooling_layer=False)
  1039. classifier_dropout = (
  1040. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1041. )
  1042. self.dropout = nn.Dropout(classifier_dropout)
  1043. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1044. # Initialize weights and apply final processing
  1045. self.post_init()
  1046. @can_return_tuple
  1047. @auto_docstring
  1048. def forward(
  1049. self,
  1050. input_ids: torch.LongTensor | None = None,
  1051. lang_ids: torch.LongTensor | None = None,
  1052. attention_mask: torch.FloatTensor | None = None,
  1053. token_type_ids: torch.LongTensor | None = None,
  1054. position_ids: torch.LongTensor | None = None,
  1055. inputs_embeds: torch.FloatTensor | None = None,
  1056. labels: torch.LongTensor | None = None,
  1057. **kwargs: Unpack[TransformersKwargs],
  1058. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  1059. r"""
  1060. lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1061. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  1062. that corresponds to `self.config.default_language`.
  1063. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1064. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1065. """
  1066. outputs = self.roberta(
  1067. input_ids,
  1068. lang_ids=lang_ids,
  1069. attention_mask=attention_mask,
  1070. token_type_ids=token_type_ids,
  1071. position_ids=position_ids,
  1072. inputs_embeds=inputs_embeds,
  1073. return_dict=True,
  1074. **kwargs,
  1075. )
  1076. sequence_output = outputs[0]
  1077. sequence_output = self.dropout(sequence_output)
  1078. logits = self.classifier(sequence_output)
  1079. loss = None
  1080. if labels is not None:
  1081. loss_fct = CrossEntropyLoss()
  1082. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1083. return TokenClassifierOutput(
  1084. loss=loss,
  1085. logits=logits,
  1086. hidden_states=outputs.hidden_states,
  1087. attentions=outputs.attentions,
  1088. )
  1089. # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead
  1090. class XmodClassificationHead(nn.Module):
  1091. """Head for sentence-level classification tasks."""
  1092. def __init__(self, config):
  1093. super().__init__()
  1094. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  1095. classifier_dropout = (
  1096. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1097. )
  1098. self.dropout = nn.Dropout(classifier_dropout)
  1099. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  1100. def forward(self, features, **kwargs):
  1101. x = features[:, 0, :] # take <s> token (equiv. to [CLS])
  1102. x = self.dropout(x)
  1103. x = self.dense(x)
  1104. x = torch.tanh(x)
  1105. x = self.dropout(x)
  1106. x = self.out_proj(x)
  1107. return x
  1108. @auto_docstring
  1109. class XmodForQuestionAnswering(XmodPreTrainedModel):
  1110. # Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering.__init__ with Roberta->Xmod
  1111. def __init__(self, config):
  1112. super().__init__(config)
  1113. self.num_labels = config.num_labels
  1114. self.roberta = XmodModel(config, add_pooling_layer=False)
  1115. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1116. # Initialize weights and apply final processing
  1117. self.post_init()
  1118. @can_return_tuple
  1119. @auto_docstring
  1120. def forward(
  1121. self,
  1122. input_ids: torch.LongTensor | None = None,
  1123. lang_ids: torch.LongTensor | None = None,
  1124. attention_mask: torch.FloatTensor | None = None,
  1125. token_type_ids: torch.LongTensor | None = None,
  1126. position_ids: torch.LongTensor | None = None,
  1127. inputs_embeds: torch.FloatTensor | None = None,
  1128. start_positions: torch.LongTensor | None = None,
  1129. end_positions: torch.LongTensor | None = None,
  1130. **kwargs: Unpack[TransformersKwargs],
  1131. ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
  1132. r"""
  1133. lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1134. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  1135. that corresponds to `self.config.default_language`.
  1136. """
  1137. outputs = self.roberta(
  1138. input_ids,
  1139. lang_ids=lang_ids,
  1140. attention_mask=attention_mask,
  1141. token_type_ids=token_type_ids,
  1142. position_ids=position_ids,
  1143. inputs_embeds=inputs_embeds,
  1144. return_dict=True,
  1145. **kwargs,
  1146. )
  1147. sequence_output = outputs[0]
  1148. logits = self.qa_outputs(sequence_output)
  1149. start_logits, end_logits = logits.split(1, dim=-1)
  1150. start_logits = start_logits.squeeze(-1).contiguous()
  1151. end_logits = end_logits.squeeze(-1).contiguous()
  1152. total_loss = None
  1153. if start_positions is not None and end_positions is not None:
  1154. # If we are on multi-GPU, split add a dimension
  1155. if len(start_positions.size()) > 1:
  1156. start_positions = start_positions.squeeze(-1)
  1157. if len(end_positions.size()) > 1:
  1158. end_positions = end_positions.squeeze(-1)
  1159. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1160. ignored_index = start_logits.size(1)
  1161. start_positions = start_positions.clamp(0, ignored_index)
  1162. end_positions = end_positions.clamp(0, ignored_index)
  1163. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1164. start_loss = loss_fct(start_logits, start_positions)
  1165. end_loss = loss_fct(end_logits, end_positions)
  1166. total_loss = (start_loss + end_loss) / 2
  1167. return QuestionAnsweringModelOutput(
  1168. loss=total_loss,
  1169. start_logits=start_logits,
  1170. end_logits=end_logits,
  1171. hidden_states=outputs.hidden_states,
  1172. attentions=outputs.attentions,
  1173. )
  1174. __all__ = [
  1175. "XmodForCausalLM",
  1176. "XmodForMaskedLM",
  1177. "XmodForMultipleChoice",
  1178. "XmodForQuestionAnswering",
  1179. "XmodForSequenceClassification",
  1180. "XmodForTokenClassification",
  1181. "XmodModel",
  1182. "XmodPreTrainedModel",
  1183. ]