modeling_mobilebert.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251
  1. # MIT License
  2. #
  3. # Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
  4. #
  5. # Permission is hereby granted, free of charge, to any person obtaining a copy
  6. # of this software and associated documentation files (the "Software"), to deal
  7. # in the Software without restriction, including without limitation the rights
  8. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9. # copies of the Software, and to permit persons to whom the Software is
  10. # furnished to do so, subject to the following conditions:
  11. #
  12. # The above copyright notice and this permission notice shall be included in all
  13. # copies or substantial portions of the Software.
  14. #
  15. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  17. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  18. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  19. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  20. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  21. # SOFTWARE.
  22. from collections.abc import Callable
  23. from dataclasses import dataclass
  24. import torch
  25. from torch import nn
  26. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  27. from ... import initialization as init
  28. from ...activations import ACT2FN
  29. from ...masking_utils import create_bidirectional_mask
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import (
  32. BaseModelOutput,
  33. BaseModelOutputWithPooling,
  34. MaskedLMOutput,
  35. MultipleChoiceModelOutput,
  36. NextSentencePredictorOutput,
  37. QuestionAnsweringModelOutput,
  38. SequenceClassifierOutput,
  39. TokenClassifierOutput,
  40. )
  41. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  42. from ...processing_utils import Unpack
  43. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
  44. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  45. from ...utils.output_capturing import capture_outputs
  46. from .configuration_mobilebert import MobileBertConfig
# Module-level logger from the library's own logging utilities (`...utils.logging`), named after this module.
logger = logging.get_logger(__name__)
  48. class NoNorm(nn.Module):
  49. def __init__(self, feat_size, eps=None):
  50. super().__init__()
  51. self.bias = nn.Parameter(torch.zeros(feat_size))
  52. self.weight = nn.Parameter(torch.ones(feat_size))
  53. def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
  54. return input_tensor * self.weight + self.bias
  55. NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm}
  56. class MobileBertEmbeddings(nn.Module):
  57. """Construct the embeddings from word, position and token_type embeddings."""
  58. def __init__(self, config):
  59. super().__init__()
  60. self.trigram_input = config.trigram_input
  61. self.embedding_size = config.embedding_size
  62. self.hidden_size = config.hidden_size
  63. self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
  64. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  65. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  66. embed_dim_multiplier = 3 if self.trigram_input else 1
  67. embedded_input_size = self.embedding_size * embed_dim_multiplier
  68. self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size)
  69. self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
  70. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  71. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  72. self.register_buffer(
  73. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  74. )
  75. def forward(
  76. self,
  77. input_ids: torch.LongTensor | None = None,
  78. token_type_ids: torch.LongTensor | None = None,
  79. position_ids: torch.LongTensor | None = None,
  80. inputs_embeds: torch.FloatTensor | None = None,
  81. ) -> torch.Tensor:
  82. if input_ids is not None:
  83. input_shape = input_ids.size()
  84. else:
  85. input_shape = inputs_embeds.size()[:-1]
  86. seq_length = input_shape[1]
  87. if position_ids is None:
  88. position_ids = self.position_ids[:, :seq_length]
  89. if token_type_ids is None:
  90. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  91. if inputs_embeds is None:
  92. inputs_embeds = self.word_embeddings(input_ids)
  93. if self.trigram_input:
  94. # From the paper MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited
  95. # Devices (https://huggingface.co/papers/2004.02984)
  96. #
  97. # The embedding table in BERT models accounts for a substantial proportion of model size. To compress
  98. # the embedding layer, we reduce the embedding dimension to 128 in MobileBERT.
  99. # Then, we apply a 1D convolution with kernel size 3 on the raw token embedding to produce a 512
  100. # dimensional output.
  101. inputs_embeds = torch.cat(
  102. [
  103. nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),
  104. inputs_embeds,
  105. nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),
  106. ],
  107. dim=2,
  108. )
  109. if self.trigram_input or self.embedding_size != self.hidden_size:
  110. inputs_embeds = self.embedding_transformation(inputs_embeds)
  111. # Add positional embeddings and token type embeddings, then layer
  112. # normalize and perform dropout.
  113. position_embeddings = self.position_embeddings(position_ids)
  114. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  115. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  116. embeddings = self.LayerNorm(embeddings)
  117. embeddings = self.dropout(embeddings)
  118. return embeddings
# Reference ("eager") attention: softmax(query @ key^T * scaling + attention_mask) @ value, with dropout applied
# to the attention weights. MobileBertSelfAttention passes this function as the default to
# `ALL_ATTENTION_FUNCTIONS.get_interface(config._attn_implementation, ...)`.
#   query/key/value: dims 2 and 3 are multiplied as (sequence, head_dim); presumably laid out as
#       (batch, num_heads, seq_len, head_dim) -- TODO confirm against callers.
#   attention_mask: ADDED to the raw scores, so it must be an additive mask broadcastable to the score shape
#       (NOTE(review): assumed to come from create_bidirectional_mask -- confirm).
#   scaling: defaults to head_dim ** -0.5 (last dim of `query`) when None.
#   dropout: probability applied to the attention weights only when `module.training` is True.
# Returns (attn_output with dims 1 and 2 swapped and made contiguous, attn_weights after softmax and dropout).
# The function below mirrors the BERT implementation named in the marker comment; keep it in sync with that source.
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
  141. class MobileBertSelfAttention(nn.Module):
  142. def __init__(self, config):
  143. super().__init__()
  144. self.config = config
  145. self.num_attention_heads = config.num_attention_heads
  146. self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)
  147. self.all_head_size = self.num_attention_heads * self.attention_head_size
  148. self.scaling = self.attention_head_size**-0.5
  149. self.query = nn.Linear(config.true_hidden_size, self.all_head_size)
  150. self.key = nn.Linear(config.true_hidden_size, self.all_head_size)
  151. self.value = nn.Linear(
  152. config.true_hidden_size if config.use_bottleneck_attention else config.hidden_size, self.all_head_size
  153. )
  154. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  155. self.is_causal = False
  156. def forward(
  157. self,
  158. query_tensor: torch.Tensor,
  159. key_tensor: torch.Tensor,
  160. value_tensor: torch.Tensor,
  161. attention_mask: torch.FloatTensor | None = None,
  162. **kwargs: Unpack[TransformersKwargs],
  163. ) -> tuple[torch.Tensor]:
  164. input_shape = query_tensor.shape[:-1]
  165. hidden_shape = (*input_shape, -1, self.attention_head_size)
  166. # get all proj
  167. query_layer = self.query(query_tensor).view(*hidden_shape).transpose(1, 2)
  168. key_layer = self.key(key_tensor).view(*hidden_shape).transpose(1, 2)
  169. value_layer = self.value(value_tensor).view(*hidden_shape).transpose(1, 2)
  170. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  171. self.config._attn_implementation, eager_attention_forward
  172. )
  173. attn_output, attn_weights = attention_interface(
  174. self,
  175. query_layer,
  176. key_layer,
  177. value_layer,
  178. attention_mask,
  179. dropout=0.0 if not self.training else self.dropout.p,
  180. scaling=self.scaling,
  181. **kwargs,
  182. )
  183. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  184. return attn_output, attn_weights
  185. class MobileBertSelfOutput(nn.Module):
  186. def __init__(self, config):
  187. super().__init__()
  188. self.use_bottleneck = config.use_bottleneck
  189. self.dense = nn.Linear(config.true_hidden_size, config.true_hidden_size)
  190. self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps)
  191. if not self.use_bottleneck:
  192. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  193. def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
  194. layer_outputs = self.dense(hidden_states)
  195. if not self.use_bottleneck:
  196. layer_outputs = self.dropout(layer_outputs)
  197. layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
  198. return layer_outputs
  199. class MobileBertAttention(nn.Module):
  200. def __init__(self, config):
  201. super().__init__()
  202. self.self = MobileBertSelfAttention(config)
  203. self.output = MobileBertSelfOutput(config)
  204. def forward(
  205. self,
  206. query_tensor: torch.Tensor,
  207. key_tensor: torch.Tensor,
  208. value_tensor: torch.Tensor,
  209. layer_input: torch.Tensor,
  210. attention_mask: torch.FloatTensor | None = None,
  211. **kwargs: Unpack[TransformersKwargs],
  212. ) -> tuple[torch.Tensor]:
  213. attention_output, attn_weights = self.self(
  214. query_tensor,
  215. key_tensor,
  216. value_tensor,
  217. attention_mask,
  218. **kwargs,
  219. )
  220. # Run a linear projection of `hidden_size` then add a residual
  221. # with `layer_input`.
  222. attention_output = self.output(attention_output, layer_input)
  223. return attention_output, attn_weights
  224. class MobileBertIntermediate(nn.Module):
  225. def __init__(self, config):
  226. super().__init__()
  227. self.dense = nn.Linear(config.true_hidden_size, config.intermediate_size)
  228. if isinstance(config.hidden_act, str):
  229. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  230. else:
  231. self.intermediate_act_fn = config.hidden_act
  232. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  233. hidden_states = self.dense(hidden_states)
  234. hidden_states = self.intermediate_act_fn(hidden_states)
  235. return hidden_states
  236. class OutputBottleneck(nn.Module):
  237. def __init__(self, config):
  238. super().__init__()
  239. self.dense = nn.Linear(config.true_hidden_size, config.hidden_size)
  240. self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size, eps=config.layer_norm_eps)
  241. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  242. def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
  243. layer_outputs = self.dense(hidden_states)
  244. layer_outputs = self.dropout(layer_outputs)
  245. layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
  246. return layer_outputs
  247. class MobileBertOutput(nn.Module):
  248. def __init__(self, config):
  249. super().__init__()
  250. self.use_bottleneck = config.use_bottleneck
  251. self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size)
  252. self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size)
  253. if not self.use_bottleneck:
  254. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  255. else:
  256. self.bottleneck = OutputBottleneck(config)
  257. def forward(
  258. self, intermediate_states: torch.Tensor, residual_tensor_1: torch.Tensor, residual_tensor_2: torch.Tensor
  259. ) -> torch.Tensor:
  260. layer_output = self.dense(intermediate_states)
  261. if not self.use_bottleneck:
  262. layer_output = self.dropout(layer_output)
  263. layer_output = self.LayerNorm(layer_output + residual_tensor_1)
  264. else:
  265. layer_output = self.LayerNorm(layer_output + residual_tensor_1)
  266. layer_output = self.bottleneck(layer_output, residual_tensor_2)
  267. return layer_output
  268. class BottleneckLayer(nn.Module):
  269. def __init__(self, config):
  270. super().__init__()
  271. self.dense = nn.Linear(config.hidden_size, config.intra_bottleneck_size)
  272. self.LayerNorm = NORM2FN[config.normalization_type](config.intra_bottleneck_size, eps=config.layer_norm_eps)
  273. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  274. layer_input = self.dense(hidden_states)
  275. layer_input = self.LayerNorm(layer_input)
  276. return layer_input
  277. class Bottleneck(nn.Module):
  278. def __init__(self, config):
  279. super().__init__()
  280. self.key_query_shared_bottleneck = config.key_query_shared_bottleneck
  281. self.use_bottleneck_attention = config.use_bottleneck_attention
  282. self.input = BottleneckLayer(config)
  283. if self.key_query_shared_bottleneck:
  284. self.attention = BottleneckLayer(config)
  285. def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor]:
  286. # This method can return three different tuples of values. These different values make use of bottlenecks,
  287. # which are linear layers used to project the hidden states to a lower-dimensional vector, reducing memory
  288. # usage. These linear layer have weights that are learned during training.
  289. #
  290. # If `config.use_bottleneck_attention`, it will return the result of the bottleneck layer four times for the
  291. # key, query, value, and "layer input" to be used by the attention layer.
  292. # This bottleneck is used to project the hidden. This last layer input will be used as a residual tensor
  293. # in the attention self output, after the attention scores have been computed.
  294. #
  295. # If not `config.use_bottleneck_attention` and `config.key_query_shared_bottleneck`, this will return
  296. # four values, three of which have been passed through a bottleneck: the query and key, passed through the same
  297. # bottleneck, and the residual layer to be applied in the attention self output, through another bottleneck.
  298. #
  299. # Finally, in the last case, the values for the query, key and values are the hidden states without bottleneck,
  300. # and the residual layer will be this value passed through a bottleneck.
  301. bottlenecked_hidden_states = self.input(hidden_states)
  302. if self.use_bottleneck_attention:
  303. return (bottlenecked_hidden_states,) * 4
  304. elif self.key_query_shared_bottleneck:
  305. shared_attention_input = self.attention(hidden_states)
  306. return (shared_attention_input, shared_attention_input, hidden_states, bottlenecked_hidden_states)
  307. else:
  308. return (hidden_states, hidden_states, hidden_states, bottlenecked_hidden_states)
  309. class FFNOutput(nn.Module):
  310. def __init__(self, config):
  311. super().__init__()
  312. self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size)
  313. self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps)
  314. def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
  315. layer_outputs = self.dense(hidden_states)
  316. layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
  317. return layer_outputs
  318. class FFNLayer(nn.Module):
  319. def __init__(self, config):
  320. super().__init__()
  321. self.intermediate = MobileBertIntermediate(config)
  322. self.output = FFNOutput(config)
  323. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  324. intermediate_output = self.intermediate(hidden_states)
  325. layer_outputs = self.output(intermediate_output, hidden_states)
  326. return layer_outputs
  327. class MobileBertLayer(GradientCheckpointingLayer):
  328. def __init__(self, config):
  329. super().__init__()
  330. self.use_bottleneck = config.use_bottleneck
  331. self.num_feedforward_networks = config.num_feedforward_networks
  332. self.attention = MobileBertAttention(config)
  333. self.intermediate = MobileBertIntermediate(config)
  334. self.output = MobileBertOutput(config)
  335. if self.use_bottleneck:
  336. self.bottleneck = Bottleneck(config)
  337. if config.num_feedforward_networks > 1:
  338. self.ffn = nn.ModuleList([FFNLayer(config) for _ in range(config.num_feedforward_networks - 1)])
  339. def forward(
  340. self,
  341. hidden_states: torch.Tensor,
  342. attention_mask: torch.FloatTensor | None = None,
  343. **kwargs: Unpack[TransformersKwargs],
  344. ) -> torch.Tensor:
  345. if self.use_bottleneck:
  346. query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
  347. else:
  348. query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4
  349. self_attention_output, _ = self.attention(
  350. query_tensor,
  351. key_tensor,
  352. value_tensor,
  353. layer_input,
  354. attention_mask,
  355. **kwargs,
  356. )
  357. attention_output = self_attention_output
  358. if self.num_feedforward_networks != 1:
  359. for ffn_module in self.ffn:
  360. attention_output = ffn_module(attention_output)
  361. intermediate_output = self.intermediate(attention_output)
  362. layer_output = self.output(intermediate_output, attention_output, hidden_states)
  363. return layer_output
  364. class MobileBertEncoder(nn.Module):
  365. def __init__(self, config):
  366. super().__init__()
  367. self.layer = nn.ModuleList([MobileBertLayer(config) for _ in range(config.num_hidden_layers)])
  368. def forward(
  369. self,
  370. hidden_states: torch.Tensor,
  371. attention_mask: torch.FloatTensor | None = None,
  372. **kwargs: Unpack[TransformersKwargs],
  373. ) -> tuple | BaseModelOutput:
  374. for i, layer_module in enumerate(self.layer):
  375. hidden_states = layer_module(
  376. hidden_states,
  377. attention_mask,
  378. **kwargs,
  379. )
  380. return BaseModelOutput(last_hidden_state=hidden_states)
  381. class MobileBertPooler(nn.Module):
  382. def __init__(self, config):
  383. super().__init__()
  384. self.do_activate = config.classifier_activation
  385. if self.do_activate:
  386. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  387. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  388. # We "pool" the model by simply taking the hidden state corresponding
  389. # to the first token.
  390. first_token_tensor = hidden_states[:, 0]
  391. if not self.do_activate:
  392. return first_token_tensor
  393. else:
  394. pooled_output = self.dense(first_token_tensor)
  395. pooled_output = torch.tanh(pooled_output)
  396. return pooled_output
  397. class MobileBertPredictionHeadTransform(nn.Module):
  398. def __init__(self, config):
  399. super().__init__()
  400. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  401. if isinstance(config.hidden_act, str):
  402. self.transform_act_fn = ACT2FN[config.hidden_act]
  403. else:
  404. self.transform_act_fn = config.hidden_act
  405. self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, eps=config.layer_norm_eps)
  406. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  407. hidden_states = self.dense(hidden_states)
  408. hidden_states = self.transform_act_fn(hidden_states)
  409. hidden_states = self.LayerNorm(hidden_states)
  410. return hidden_states
  411. class MobileBertLMPredictionHead(nn.Module):
  412. def __init__(self, config):
  413. super().__init__()
  414. self.transform = MobileBertPredictionHeadTransform(config)
  415. # The output weights are the same as the input embeddings, but there is
  416. # an output-only bias for each token.
  417. self.dense = nn.Linear(config.vocab_size, config.hidden_size - config.embedding_size, bias=False)
  418. self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=True)
  419. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  420. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  421. hidden_states = self.transform(hidden_states)
  422. hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
  423. hidden_states += self.decoder.bias
  424. return hidden_states
  425. class MobileBertOnlyMLMHead(nn.Module):
  426. def __init__(self, config):
  427. super().__init__()
  428. self.predictions = MobileBertLMPredictionHead(config)
  429. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  430. prediction_scores = self.predictions(sequence_output)
  431. return prediction_scores
  432. class MobileBertPreTrainingHeads(nn.Module):
  433. def __init__(self, config):
  434. super().__init__()
  435. self.predictions = MobileBertLMPredictionHead(config)
  436. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  437. def forward(self, sequence_output: torch.Tensor, pooled_output: torch.Tensor) -> tuple[torch.Tensor]:
  438. prediction_scores = self.predictions(sequence_output)
  439. seq_relationship_score = self.seq_relationship(pooled_output)
  440. return prediction_scores, seq_relationship_score
  441. @auto_docstring
  442. class MobileBertPreTrainedModel(PreTrainedModel):
  443. config: MobileBertConfig
  444. base_model_prefix = "mobilebert"
  445. supports_gradient_checkpointing = True
  446. _supports_flash_attn = True
  447. _supports_sdpa = True
  448. _supports_flex_attn = True
  449. _supports_attention_backend = True
  450. _can_record_outputs = {
  451. "hidden_states": MobileBertLayer,
  452. "attentions": MobileBertSelfAttention,
  453. }
  454. @torch.no_grad()
  455. def _init_weights(self, module):
  456. """Initialize the weights"""
  457. super()._init_weights(module)
  458. if isinstance(module, NoNorm):
  459. init.zeros_(module.bias)
  460. init.ones_(module.weight)
  461. elif isinstance(module, MobileBertLMPredictionHead):
  462. init.zeros_(module.bias)
  463. elif isinstance(module, MobileBertEmbeddings):
  464. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`MobileBertForPreTraining`].
    """
)
class MobileBertForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
        before SoftMax).
    """

    # NOTE(review): field declaration order presumably defines the order used for index/tuple access on
    # ModelOutput instances, so keep it stable.
    loss: torch.FloatTensor | None = None
    prediction_logits: torch.FloatTensor | None = None
    seq_relationship_logits: torch.FloatTensor | None = None
    # Optional per-layer hidden states and attention maps; left as None unless populated by the producer.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
  487. @auto_docstring
  488. class MobileBertModel(MobileBertPreTrainedModel):
  489. """
  490. https://huggingface.co/papers/2004.02984
  491. """
  492. def __init__(self, config, add_pooling_layer=True):
  493. r"""
  494. add_pooling_layer (bool, *optional*, defaults to `True`):
  495. Whether to add a pooling layer
  496. """
  497. super().__init__(config)
  498. self.config = config
  499. self.gradient_checkpointing = False
  500. self.embeddings = MobileBertEmbeddings(config)
  501. self.encoder = MobileBertEncoder(config)
  502. self.pooler = MobileBertPooler(config) if add_pooling_layer else None
  503. # Initialize weights and apply final processing
  504. self.post_init()
  505. def get_input_embeddings(self):
  506. return self.embeddings.word_embeddings
  507. def set_input_embeddings(self, value):
  508. self.embeddings.word_embeddings = value
  509. @merge_with_config_defaults
  510. @capture_outputs
  511. @auto_docstring
  512. def forward(
  513. self,
  514. input_ids: torch.LongTensor | None = None,
  515. attention_mask: torch.FloatTensor | None = None,
  516. token_type_ids: torch.LongTensor | None = None,
  517. position_ids: torch.LongTensor | None = None,
  518. inputs_embeds: torch.FloatTensor | None = None,
  519. **kwargs: Unpack[TransformersKwargs],
  520. ) -> tuple | BaseModelOutputWithPooling:
  521. if (input_ids is None) ^ (inputs_embeds is not None):
  522. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  523. embedding_output = self.embeddings(
  524. input_ids=input_ids,
  525. position_ids=position_ids,
  526. token_type_ids=token_type_ids,
  527. inputs_embeds=inputs_embeds,
  528. )
  529. attention_mask = create_bidirectional_mask(
  530. config=self.config,
  531. inputs_embeds=embedding_output,
  532. attention_mask=attention_mask,
  533. )
  534. encoder_outputs = self.encoder(
  535. embedding_output,
  536. attention_mask=attention_mask,
  537. **kwargs,
  538. )
  539. sequence_output = encoder_outputs[0]
  540. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  541. return BaseModelOutputWithPooling(
  542. last_hidden_state=sequence_output,
  543. pooler_output=pooled_output,
  544. )
  545. @auto_docstring(
  546. custom_intro="""
  547. MobileBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
  548. `next sentence prediction (classification)` head.
  549. """
  550. )
  551. class MobileBertForPreTraining(MobileBertPreTrainedModel):
  552. _tied_weights_keys = {
  553. "cls.predictions.decoder.bias": "cls.predictions.bias",
  554. "cls.predictions.decoder.weight": "mobilebert.embeddings.word_embeddings.weight",
  555. }
  556. def __init__(self, config):
  557. super().__init__(config)
  558. self.mobilebert = MobileBertModel(config)
  559. self.cls = MobileBertPreTrainingHeads(config)
  560. # Initialize weights and apply final processing
  561. self.post_init()
  562. def get_output_embeddings(self):
  563. return self.cls.predictions.decoder
  564. def set_output_embeddings(self, new_embeddings):
  565. self.cls.predictions.decoder = new_embeddings
  566. self.cls.predictions.bias = new_embeddings.bias
  567. def resize_token_embeddings(self, new_num_tokens: int | None = None) -> nn.Embedding:
  568. # resize dense output embedings at first
  569. self.cls.predictions.dense = self._get_resized_lm_head(
  570. self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
  571. )
  572. return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
  573. @can_return_tuple
  574. @auto_docstring
  575. def forward(
  576. self,
  577. input_ids: torch.LongTensor | None = None,
  578. attention_mask: torch.FloatTensor | None = None,
  579. token_type_ids: torch.LongTensor | None = None,
  580. position_ids: torch.LongTensor | None = None,
  581. inputs_embeds: torch.FloatTensor | None = None,
  582. labels: torch.LongTensor | None = None,
  583. next_sentence_label: torch.LongTensor | None = None,
  584. **kwargs: Unpack[TransformersKwargs],
  585. ) -> tuple | MobileBertForPreTrainingOutput:
  586. r"""
  587. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  588. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  589. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  590. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  591. next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  592. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
  593. (see `input_ids` docstring) Indices should be in `[0, 1]`:
  594. - 0 indicates sequence B is a continuation of sequence A,
  595. - 1 indicates sequence B is a random sequence.
  596. Examples:
  597. ```python
  598. >>> from transformers import AutoTokenizer, MobileBertForPreTraining
  599. >>> import torch
  600. >>> tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
  601. >>> model = MobileBertForPreTraining.from_pretrained("google/mobilebert-uncased")
  602. >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
  603. >>> # Batch size 1
  604. >>> outputs = model(input_ids)
  605. >>> prediction_logits = outputs.prediction_logits
  606. >>> seq_relationship_logits = outputs.seq_relationship_logits
  607. ```"""
  608. outputs = self.mobilebert(
  609. input_ids,
  610. attention_mask=attention_mask,
  611. token_type_ids=token_type_ids,
  612. position_ids=position_ids,
  613. inputs_embeds=inputs_embeds,
  614. return_dict=True,
  615. **kwargs,
  616. )
  617. sequence_output, pooled_output = outputs[:2]
  618. prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
  619. total_loss = None
  620. if labels is not None and next_sentence_label is not None:
  621. loss_fct = CrossEntropyLoss()
  622. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  623. next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
  624. total_loss = masked_lm_loss + next_sentence_loss
  625. return MobileBertForPreTrainingOutput(
  626. loss=total_loss,
  627. prediction_logits=prediction_scores,
  628. seq_relationship_logits=seq_relationship_score,
  629. hidden_states=outputs.hidden_states,
  630. attentions=outputs.attentions,
  631. )
  632. @auto_docstring
  633. class MobileBertForMaskedLM(MobileBertPreTrainedModel):
  634. _tied_weights_keys = {
  635. "cls.predictions.decoder.bias": "cls.predictions.bias",
  636. "cls.predictions.decoder.weight": "mobilebert.embeddings.word_embeddings.weight",
  637. }
  638. def __init__(self, config):
  639. super().__init__(config)
  640. self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
  641. self.cls = MobileBertOnlyMLMHead(config)
  642. self.config = config
  643. # Initialize weights and apply final processing
  644. self.post_init()
  645. def get_output_embeddings(self):
  646. return self.cls.predictions.decoder
  647. def set_output_embeddings(self, new_embeddings):
  648. self.cls.predictions.decoder = new_embeddings
  649. self.cls.predictions.bias = new_embeddings.bias
  650. def resize_token_embeddings(self, new_num_tokens: int | None = None) -> nn.Embedding:
  651. # resize dense output embedings at first
  652. self.cls.predictions.dense = self._get_resized_lm_head(
  653. self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
  654. )
  655. return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
  656. @can_return_tuple
  657. @auto_docstring
  658. def forward(
  659. self,
  660. input_ids: torch.LongTensor | None = None,
  661. attention_mask: torch.FloatTensor | None = None,
  662. token_type_ids: torch.LongTensor | None = None,
  663. position_ids: torch.LongTensor | None = None,
  664. inputs_embeds: torch.FloatTensor | None = None,
  665. labels: torch.LongTensor | None = None,
  666. **kwargs: Unpack[TransformersKwargs],
  667. ) -> tuple | MaskedLMOutput:
  668. r"""
  669. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  670. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  671. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  672. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  673. """
  674. outputs = self.mobilebert(
  675. input_ids,
  676. attention_mask=attention_mask,
  677. token_type_ids=token_type_ids,
  678. position_ids=position_ids,
  679. inputs_embeds=inputs_embeds,
  680. return_dict=True,
  681. **kwargs,
  682. )
  683. sequence_output = outputs[0]
  684. prediction_scores = self.cls(sequence_output)
  685. masked_lm_loss = None
  686. if labels is not None:
  687. loss_fct = CrossEntropyLoss() # -100 index = padding token
  688. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  689. return MaskedLMOutput(
  690. loss=masked_lm_loss,
  691. logits=prediction_scores,
  692. hidden_states=outputs.hidden_states,
  693. attentions=outputs.attentions,
  694. )
  695. class MobileBertOnlyNSPHead(nn.Module):
  696. def __init__(self, config):
  697. super().__init__()
  698. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  699. def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
  700. seq_relationship_score = self.seq_relationship(pooled_output)
  701. return seq_relationship_score
  702. @auto_docstring(
  703. custom_intro="""
  704. MobileBert Model with a `next sentence prediction (classification)` head on top.
  705. """
  706. )
  707. class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
  708. def __init__(self, config):
  709. super().__init__(config)
  710. self.mobilebert = MobileBertModel(config)
  711. self.cls = MobileBertOnlyNSPHead(config)
  712. # Initialize weights and apply final processing
  713. self.post_init()
  714. @can_return_tuple
  715. @auto_docstring
  716. def forward(
  717. self,
  718. input_ids: torch.LongTensor | None = None,
  719. attention_mask: torch.FloatTensor | None = None,
  720. token_type_ids: torch.LongTensor | None = None,
  721. position_ids: torch.LongTensor | None = None,
  722. inputs_embeds: torch.FloatTensor | None = None,
  723. labels: torch.LongTensor | None = None,
  724. **kwargs: Unpack[TransformersKwargs],
  725. ) -> tuple | NextSentencePredictorOutput:
  726. r"""
  727. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  728. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
  729. (see `input_ids` docstring) Indices should be in `[0, 1]`.
  730. - 0 indicates sequence B is a continuation of sequence A,
  731. - 1 indicates sequence B is a random sequence.
  732. Examples:
  733. ```python
  734. >>> from transformers import AutoTokenizer, MobileBertForNextSentencePrediction
  735. >>> import torch
  736. >>> tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
  737. >>> model = MobileBertForNextSentencePrediction.from_pretrained("google/mobilebert-uncased")
  738. >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
  739. >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
  740. >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
  741. >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
  742. >>> loss = outputs.loss
  743. >>> logits = outputs.logits
  744. ```"""
  745. outputs = self.mobilebert(
  746. input_ids,
  747. attention_mask=attention_mask,
  748. token_type_ids=token_type_ids,
  749. position_ids=position_ids,
  750. inputs_embeds=inputs_embeds,
  751. return_dict=True,
  752. **kwargs,
  753. )
  754. pooled_output = outputs[1]
  755. seq_relationship_score = self.cls(pooled_output)
  756. next_sentence_loss = None
  757. if labels is not None:
  758. loss_fct = CrossEntropyLoss()
  759. next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), labels.view(-1))
  760. return NextSentencePredictorOutput(
  761. loss=next_sentence_loss,
  762. logits=seq_relationship_score,
  763. hidden_states=outputs.hidden_states,
  764. attentions=outputs.attentions,
  765. )
@auto_docstring(
    custom_intro="""
    MobileBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """
)
# Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification with Bert->MobileBert all-casing
class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
    # Sequence-level head: dropout + linear layer over the pooled output, producing `config.num_labels` logits.
    # NOTE(review): the `# Copied from` marker indicates this body mirrors the Bert original; change that one first.
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.mobilebert = MobileBertModel(config)
        # Use the classifier-specific dropout rate when configured, otherwise reuse the hidden dropout rate.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        # Element 1 of the encoder output is the pooled representation (the encoder is built with its pooler).
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type on first use. This writes to `self.config`, so the inferred value
            # persists for all later calls:
            #   num_labels == 1                      -> regression
            #   num_labels > 1 and integer labels     -> single-label classification
            #   anything else (e.g. float multi-hot) -> multi-label classification
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    # Squeeze both sides so a (batch, 1) prediction is compared against (batch,) targets.
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring
# Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering with Bert->MobileBert all-casing
class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
    # Extractive QA head: a per-token linear layer whose output is split into span-start and span-end logits.
    # NOTE(review): the `# Copied from` marker indicates this body mirrors the Bert original; change that one first.
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Built without the pooling layer: the head only consumes per-token hidden states.
        self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        start_positions: torch.Tensor | None = None,
        end_positions: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
        outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        # Split the last dimension into width-1 chunks. Unpacking into exactly two names means this head
        # requires `config.num_labels == 2` (start, end).
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the gather step adds an extra trailing dimension; squeeze it away.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms.
            # Positions above the sequence length are clamped onto `ignored_index` and then skipped by the loss.
            # Negative positions are clamped to 0 (not ignored).
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            # Final loss is the mean of the start- and end-position losses.
            total_loss = (start_loss + end_loss) / 2

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring
# Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice with Bert->MobileBert all-casing
class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
    # Multiple-choice head: scores each (context, choice) pair with a single pooled-output logit, then compares
    # choices with a softmax cross-entropy over the `num_choices` axis.
    # NOTE(review): the `# Copied from` marker indicates this body mirrors the Bert original; change that one first.
    def __init__(self, config):
        super().__init__(config)

        self.mobilebert = MobileBertModel(config)
        # Use the classifier-specific dropout rate when configured, otherwise reuse the hidden dropout rate.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # One score per choice.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # NOTE(review): if both `input_ids` and `inputs_embeds` are None, this line raises AttributeError.
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice axis into the batch axis so the encoder sees ordinary 2-D (or 3-D embedding) inputs.
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        # Element 1 of the encoder output is the pooled representation (the encoder is built with its pooler).
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Unfold back to one row per example with one logit per choice.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring
# Copied from transformers.models.bert.modeling_bert.BertForTokenClassification with Bert->MobileBert all-casing
class MobileBertForTokenClassification(MobileBertPreTrainedModel):
    # Token-level head: dropout + linear layer applied to every token's hidden state (`config.num_labels` logits each).
    # NOTE(review): the `# Copied from` marker indicates this body mirrors the Bert original; change that one first.
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Built without the pooling layer: only per-token hidden states are used.
        self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
        # Use the classifier-specific dropout rate when configured, otherwise reuse the hidden dropout rate.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Flatten tokens across the batch. CrossEntropyLoss's default ignore_index (-100) skips labelled-out tokens.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public names exported by `from <this module> import *`.
__all__ = [
    "MobileBertForMaskedLM",
    "MobileBertForMultipleChoice",
    "MobileBertForNextSentencePrediction",
    "MobileBertForPreTraining",
    "MobileBertForQuestionAnswering",
    "MobileBertForSequenceClassification",
    "MobileBertForTokenClassification",
    "MobileBertLayer",
    "MobileBertModel",
    "MobileBertPreTrainedModel",
]