modeling_squeezebert.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920
  1. # Copyright 2020 The SqueezeBert authors and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch SqueezeBert model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...modeling_outputs import (
  22. BaseModelOutput,
  23. BaseModelOutputWithPooling,
  24. MaskedLMOutput,
  25. MultipleChoiceModelOutput,
  26. QuestionAnsweringModelOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import PreTrainedModel
  31. from ...utils import (
  32. auto_docstring,
  33. logging,
  34. )
  35. from .configuration_squeezebert import SqueezeBertConfig
  36. logger = logging.get_logger(__name__)
  37. class SqueezeBertEmbeddings(nn.Module):
  38. """Construct the embeddings from word, position and token_type embeddings."""
  39. def __init__(self, config):
  40. super().__init__()
  41. self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
  42. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
  43. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
  44. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  45. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  46. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  47. self.register_buffer(
  48. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  49. )
  50. def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
  51. if input_ids is not None:
  52. input_shape = input_ids.size()
  53. else:
  54. input_shape = inputs_embeds.size()[:-1]
  55. seq_length = input_shape[1]
  56. if position_ids is None:
  57. position_ids = self.position_ids[:, :seq_length]
  58. if token_type_ids is None:
  59. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  60. if inputs_embeds is None:
  61. inputs_embeds = self.word_embeddings(input_ids)
  62. position_embeddings = self.position_embeddings(position_ids)
  63. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  64. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  65. embeddings = self.LayerNorm(embeddings)
  66. embeddings = self.dropout(embeddings)
  67. return embeddings
  68. class MatMulWrapper(nn.Module):
  69. """
  70. Wrapper for torch.matmul(). This makes flop-counting easier to implement. Note that if you directly call
  71. torch.matmul() in your code, the flop counter will typically ignore the flops of the matmul.
  72. """
  73. def __init__(self):
  74. super().__init__()
  75. def forward(self, mat1, mat2):
  76. """
  77. :param inputs: two torch tensors :return: matmul of these tensors
  78. Here are the typical dimensions found in BERT (the B is optional) mat1.shape: [B, <optional extra dims>, M, K]
  79. mat2.shape: [B, <optional extra dims>, K, N] output shape: [B, <optional extra dims>, M, N]
  80. """
  81. return torch.matmul(mat1, mat2)
  82. class SqueezeBertLayerNorm(nn.LayerNorm):
  83. """
  84. This is a nn.LayerNorm subclass that accepts NCW data layout and performs normalization in the C dimension.
  85. N = batch C = channels W = sequence length
  86. """
  87. def __init__(self, hidden_size, eps=1e-12):
  88. nn.LayerNorm.__init__(self, normalized_shape=hidden_size, eps=eps) # instantiates self.{weight, bias, eps}
  89. def forward(self, x):
  90. x = x.permute(0, 2, 1)
  91. x = nn.LayerNorm.forward(self, x)
  92. return x.permute(0, 2, 1)
  93. class ConvDropoutLayerNorm(nn.Module):
  94. """
  95. ConvDropoutLayerNorm: Conv, Dropout, LayerNorm
  96. """
  97. def __init__(self, cin, cout, groups, dropout_prob):
  98. super().__init__()
  99. self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
  100. self.layernorm = SqueezeBertLayerNorm(cout)
  101. self.dropout = nn.Dropout(dropout_prob)
  102. def forward(self, hidden_states, input_tensor):
  103. x = self.conv1d(hidden_states)
  104. x = self.dropout(x)
  105. x = x + input_tensor
  106. x = self.layernorm(x)
  107. return x
  108. class ConvActivation(nn.Module):
  109. """
  110. ConvActivation: Conv, Activation
  111. """
  112. def __init__(self, cin, cout, groups, act):
  113. super().__init__()
  114. self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
  115. self.act = ACT2FN[act]
  116. def forward(self, x):
  117. output = self.conv1d(x)
  118. return self.act(output)
  119. class SqueezeBertSelfAttention(nn.Module):
  120. def __init__(self, config, cin, q_groups=1, k_groups=1, v_groups=1):
  121. """
  122. config = used for some things; ignored for others (work in progress...) cin = input channels = output channels
  123. groups = number of groups to use in conv1d layers
  124. """
  125. super().__init__()
  126. if cin % config.num_attention_heads != 0:
  127. raise ValueError(
  128. f"cin ({cin}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
  129. )
  130. self.num_attention_heads = config.num_attention_heads
  131. self.attention_head_size = int(cin / config.num_attention_heads)
  132. self.all_head_size = self.num_attention_heads * self.attention_head_size
  133. self.query = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=q_groups)
  134. self.key = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=k_groups)
  135. self.value = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=v_groups)
  136. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  137. self.softmax = nn.Softmax(dim=-1)
  138. self.matmul_qk = MatMulWrapper()
  139. self.matmul_qkv = MatMulWrapper()
  140. def transpose_for_scores(self, x):
  141. """
  142. - input: [N, C, W]
  143. - output: [N, C1, W, C2] where C1 is the head index, and C2 is one head's contents
  144. """
  145. new_x_shape = (x.size()[0], self.num_attention_heads, self.attention_head_size, x.size()[-1]) # [N, C1, C2, W]
  146. x = x.view(*new_x_shape)
  147. return x.permute(0, 1, 3, 2) # [N, C1, C2, W] --> [N, C1, W, C2]
  148. def transpose_key_for_scores(self, x):
  149. """
  150. - input: [N, C, W]
  151. - output: [N, C1, C2, W] where C1 is the head index, and C2 is one head's contents
  152. """
  153. new_x_shape = (x.size()[0], self.num_attention_heads, self.attention_head_size, x.size()[-1]) # [N, C1, C2, W]
  154. x = x.view(*new_x_shape)
  155. # no `permute` needed
  156. return x
  157. def transpose_output(self, x):
  158. """
  159. - input: [N, C1, W, C2]
  160. - output: [N, C, W]
  161. """
  162. x = x.permute(0, 1, 3, 2).contiguous() # [N, C1, C2, W]
  163. new_x_shape = (x.size()[0], self.all_head_size, x.size()[3]) # [N, C, W]
  164. x = x.view(*new_x_shape)
  165. return x
  166. def forward(self, hidden_states, attention_mask, output_attentions):
  167. """
  168. expects hidden_states in [N, C, W] data layout.
  169. The attention_mask data layout is [N, W], and it does not need to be transposed.
  170. """
  171. mixed_query_layer = self.query(hidden_states)
  172. mixed_key_layer = self.key(hidden_states)
  173. mixed_value_layer = self.value(hidden_states)
  174. query_layer = self.transpose_for_scores(mixed_query_layer)
  175. key_layer = self.transpose_key_for_scores(mixed_key_layer)
  176. value_layer = self.transpose_for_scores(mixed_value_layer)
  177. # Take the dot product between "query" and "key" to get the raw attention scores.
  178. attention_score = self.matmul_qk(query_layer, key_layer)
  179. attention_score = attention_score / math.sqrt(self.attention_head_size)
  180. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  181. attention_score = attention_score + attention_mask
  182. # Normalize the attention scores to probabilities.
  183. attention_probs = self.softmax(attention_score)
  184. # This is actually dropping out entire tokens to attend to, which might
  185. # seem a bit unusual, but is taken from the original Transformer paper.
  186. attention_probs = self.dropout(attention_probs)
  187. context_layer = self.matmul_qkv(attention_probs, value_layer)
  188. context_layer = self.transpose_output(context_layer)
  189. result = {"context_layer": context_layer}
  190. if output_attentions:
  191. result["attention_score"] = attention_score
  192. return result
  193. class SqueezeBertModule(nn.Module):
  194. def __init__(self, config):
  195. """
  196. - hidden_size = input chans = output chans for Q, K, V (they are all the same ... for now) = output chans for
  197. the module
  198. - intermediate_size = output chans for intermediate layer
  199. - groups = number of groups for all layers in the BertModule. (eventually we could change the interface to
  200. allow different groups for different layers)
  201. """
  202. super().__init__()
  203. c0 = config.hidden_size
  204. c1 = config.hidden_size
  205. c2 = config.intermediate_size
  206. c3 = config.hidden_size
  207. self.attention = SqueezeBertSelfAttention(
  208. config=config, cin=c0, q_groups=config.q_groups, k_groups=config.k_groups, v_groups=config.v_groups
  209. )
  210. self.post_attention = ConvDropoutLayerNorm(
  211. cin=c0, cout=c1, groups=config.post_attention_groups, dropout_prob=config.hidden_dropout_prob
  212. )
  213. self.intermediate = ConvActivation(cin=c1, cout=c2, groups=config.intermediate_groups, act=config.hidden_act)
  214. self.output = ConvDropoutLayerNorm(
  215. cin=c2, cout=c3, groups=config.output_groups, dropout_prob=config.hidden_dropout_prob
  216. )
  217. def forward(self, hidden_states, attention_mask, output_attentions):
  218. att = self.attention(hidden_states, attention_mask, output_attentions)
  219. attention_output = att["context_layer"]
  220. post_attention_output = self.post_attention(attention_output, hidden_states)
  221. intermediate_output = self.intermediate(post_attention_output)
  222. layer_output = self.output(intermediate_output, post_attention_output)
  223. output_dict = {"feature_map": layer_output}
  224. if output_attentions:
  225. output_dict["attention_score"] = att["attention_score"]
  226. return output_dict
  227. class SqueezeBertEncoder(nn.Module):
  228. def __init__(self, config):
  229. super().__init__()
  230. assert config.embedding_size == config.hidden_size, (
  231. "If you want embedding_size != intermediate hidden_size, "
  232. "please insert a Conv1d layer to adjust the number of channels "
  233. "before the first SqueezeBertModule."
  234. )
  235. self.layers = nn.ModuleList(SqueezeBertModule(config) for _ in range(config.num_hidden_layers))
  236. def forward(
  237. self,
  238. hidden_states,
  239. attention_mask=None,
  240. output_attentions=False,
  241. output_hidden_states=False,
  242. return_dict=True,
  243. ):
  244. # [batch_size, sequence_length, hidden_size] --> [batch_size, hidden_size, sequence_length]
  245. hidden_states = hidden_states.permute(0, 2, 1)
  246. all_hidden_states = () if output_hidden_states else None
  247. all_attentions = () if output_attentions else None
  248. for layer in self.layers:
  249. if output_hidden_states:
  250. hidden_states = hidden_states.permute(0, 2, 1)
  251. all_hidden_states += (hidden_states,)
  252. hidden_states = hidden_states.permute(0, 2, 1)
  253. layer_output = layer.forward(hidden_states, attention_mask, output_attentions)
  254. hidden_states = layer_output["feature_map"]
  255. if output_attentions:
  256. all_attentions += (layer_output["attention_score"],)
  257. # [batch_size, hidden_size, sequence_length] --> [batch_size, sequence_length, hidden_size]
  258. hidden_states = hidden_states.permute(0, 2, 1)
  259. if output_hidden_states:
  260. all_hidden_states += (hidden_states,)
  261. if not return_dict:
  262. return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
  263. return BaseModelOutput(
  264. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
  265. )
  266. class SqueezeBertPooler(nn.Module):
  267. def __init__(self, config):
  268. super().__init__()
  269. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  270. self.activation = nn.Tanh()
  271. def forward(self, hidden_states):
  272. # We "pool" the model by simply taking the hidden state corresponding
  273. # to the first token.
  274. first_token_tensor = hidden_states[:, 0]
  275. pooled_output = self.dense(first_token_tensor)
  276. pooled_output = self.activation(pooled_output)
  277. return pooled_output
  278. class SqueezeBertPredictionHeadTransform(nn.Module):
  279. def __init__(self, config):
  280. super().__init__()
  281. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  282. if isinstance(config.hidden_act, str):
  283. self.transform_act_fn = ACT2FN[config.hidden_act]
  284. else:
  285. self.transform_act_fn = config.hidden_act
  286. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  287. def forward(self, hidden_states):
  288. hidden_states = self.dense(hidden_states)
  289. hidden_states = self.transform_act_fn(hidden_states)
  290. hidden_states = self.LayerNorm(hidden_states)
  291. return hidden_states
  292. class SqueezeBertLMPredictionHead(nn.Module):
  293. def __init__(self, config):
  294. super().__init__()
  295. self.transform = SqueezeBertPredictionHeadTransform(config)
  296. # The output weights are the same as the input embeddings, but there is
  297. # an output-only bias for each token.
  298. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
  299. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  300. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  301. def forward(self, hidden_states):
  302. hidden_states = self.transform(hidden_states)
  303. hidden_states = self.decoder(hidden_states)
  304. return hidden_states
  305. class SqueezeBertOnlyMLMHead(nn.Module):
  306. def __init__(self, config):
  307. super().__init__()
  308. self.predictions = SqueezeBertLMPredictionHead(config)
  309. def forward(self, sequence_output):
  310. prediction_scores = self.predictions(sequence_output)
  311. return prediction_scores
  312. @auto_docstring
  313. class SqueezeBertPreTrainedModel(PreTrainedModel):
  314. config: SqueezeBertConfig
  315. base_model_prefix = "transformer"
  316. @torch.no_grad()
  317. def _init_weights(self, module):
  318. """Initialize the weights"""
  319. super()._init_weights(module)
  320. if isinstance(module, SqueezeBertLMPredictionHead):
  321. init.zeros_(module.bias)
  322. elif isinstance(module, SqueezeBertEmbeddings):
  323. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  324. @auto_docstring
  325. class SqueezeBertModel(SqueezeBertPreTrainedModel):
  326. def __init__(self, config):
  327. super().__init__(config)
  328. self.embeddings = SqueezeBertEmbeddings(config)
  329. self.encoder = SqueezeBertEncoder(config)
  330. self.pooler = SqueezeBertPooler(config)
  331. # Initialize weights and apply final processing
  332. self.post_init()
  333. def get_input_embeddings(self):
  334. return self.embeddings.word_embeddings
  335. def set_input_embeddings(self, new_embeddings):
  336. self.embeddings.word_embeddings = new_embeddings
  337. @auto_docstring
  338. def forward(
  339. self,
  340. input_ids: torch.Tensor | None = None,
  341. attention_mask: torch.Tensor | None = None,
  342. token_type_ids: torch.Tensor | None = None,
  343. position_ids: torch.Tensor | None = None,
  344. inputs_embeds: torch.FloatTensor | None = None,
  345. output_attentions: bool | None = None,
  346. output_hidden_states: bool | None = None,
  347. return_dict: bool | None = None,
  348. **kwargs,
  349. ) -> tuple | BaseModelOutputWithPooling:
  350. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  351. output_hidden_states = (
  352. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  353. )
  354. return_dict = return_dict if return_dict is not None else self.config.return_dict
  355. if input_ids is not None and inputs_embeds is not None:
  356. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  357. elif input_ids is not None:
  358. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  359. input_shape = input_ids.size()
  360. elif inputs_embeds is not None:
  361. input_shape = inputs_embeds.size()[:-1]
  362. else:
  363. raise ValueError("You have to specify either input_ids or inputs_embeds")
  364. device = input_ids.device if input_ids is not None else inputs_embeds.device
  365. if attention_mask is None:
  366. attention_mask = torch.ones(input_shape, device=device)
  367. if token_type_ids is None:
  368. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  369. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
  370. embedding_output = self.embeddings(
  371. input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  372. )
  373. encoder_outputs = self.encoder(
  374. hidden_states=embedding_output,
  375. attention_mask=extended_attention_mask,
  376. output_attentions=output_attentions,
  377. output_hidden_states=output_hidden_states,
  378. return_dict=return_dict,
  379. )
  380. sequence_output = encoder_outputs[0]
  381. pooled_output = self.pooler(sequence_output)
  382. if not return_dict:
  383. return (sequence_output, pooled_output) + encoder_outputs[1:]
  384. return BaseModelOutputWithPooling(
  385. last_hidden_state=sequence_output,
  386. pooler_output=pooled_output,
  387. hidden_states=encoder_outputs.hidden_states,
  388. attentions=encoder_outputs.attentions,
  389. )
  390. @auto_docstring
  391. class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
  392. _tied_weights_keys = {
  393. "cls.predictions.decoder.bias": "cls.predictions.bias",
  394. "cls.predictions.decoder.weight": "transformer.embeddings.word_embeddings.weight",
  395. }
  396. def __init__(self, config):
  397. super().__init__(config)
  398. self.transformer = SqueezeBertModel(config)
  399. self.cls = SqueezeBertOnlyMLMHead(config)
  400. # Initialize weights and apply final processing
  401. self.post_init()
  402. def get_output_embeddings(self):
  403. return self.cls.predictions.decoder
  404. def set_output_embeddings(self, new_embeddings):
  405. self.cls.predictions.decoder = new_embeddings
  406. self.cls.predictions.bias = new_embeddings.bias
  407. @auto_docstring
  408. def forward(
  409. self,
  410. input_ids: torch.Tensor | None = None,
  411. attention_mask: torch.Tensor | None = None,
  412. token_type_ids: torch.Tensor | None = None,
  413. position_ids: torch.Tensor | None = None,
  414. inputs_embeds: torch.Tensor | None = None,
  415. labels: torch.Tensor | None = None,
  416. output_attentions: bool | None = None,
  417. output_hidden_states: bool | None = None,
  418. return_dict: bool | None = None,
  419. **kwargs,
  420. ) -> tuple | MaskedLMOutput:
  421. r"""
  422. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  423. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  424. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  425. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  426. """
  427. return_dict = return_dict if return_dict is not None else self.config.return_dict
  428. outputs = self.transformer(
  429. input_ids,
  430. attention_mask=attention_mask,
  431. token_type_ids=token_type_ids,
  432. position_ids=position_ids,
  433. inputs_embeds=inputs_embeds,
  434. output_attentions=output_attentions,
  435. output_hidden_states=output_hidden_states,
  436. return_dict=return_dict,
  437. )
  438. sequence_output = outputs[0]
  439. prediction_scores = self.cls(sequence_output)
  440. masked_lm_loss = None
  441. if labels is not None:
  442. loss_fct = CrossEntropyLoss() # -100 index = padding token
  443. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  444. if not return_dict:
  445. output = (prediction_scores,) + outputs[2:]
  446. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  447. return MaskedLMOutput(
  448. loss=masked_lm_loss,
  449. logits=prediction_scores,
  450. hidden_states=outputs.hidden_states,
  451. attentions=outputs.attentions,
  452. )
  453. @auto_docstring(
  454. custom_intro="""
  455. SqueezeBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  456. pooled output) e.g. for GLUE tasks.
  457. """
  458. )
  459. class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel):
  460. def __init__(self, config):
  461. super().__init__(config)
  462. self.num_labels = config.num_labels
  463. self.config = config
  464. self.transformer = SqueezeBertModel(config)
  465. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  466. self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
  467. # Initialize weights and apply final processing
  468. self.post_init()
  469. @auto_docstring
  470. def forward(
  471. self,
  472. input_ids: torch.Tensor | None = None,
  473. attention_mask: torch.Tensor | None = None,
  474. token_type_ids: torch.Tensor | None = None,
  475. position_ids: torch.Tensor | None = None,
  476. inputs_embeds: torch.Tensor | None = None,
  477. labels: torch.Tensor | None = None,
  478. output_attentions: bool | None = None,
  479. output_hidden_states: bool | None = None,
  480. return_dict: bool | None = None,
  481. **kwargs,
  482. ) -> tuple | SequenceClassifierOutput:
  483. r"""
  484. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  485. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  486. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  487. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  488. """
  489. return_dict = return_dict if return_dict is not None else self.config.return_dict
  490. outputs = self.transformer(
  491. input_ids,
  492. attention_mask=attention_mask,
  493. token_type_ids=token_type_ids,
  494. position_ids=position_ids,
  495. inputs_embeds=inputs_embeds,
  496. output_attentions=output_attentions,
  497. output_hidden_states=output_hidden_states,
  498. return_dict=return_dict,
  499. )
  500. pooled_output = outputs[1]
  501. pooled_output = self.dropout(pooled_output)
  502. logits = self.classifier(pooled_output)
  503. loss = None
  504. if labels is not None:
  505. if self.config.problem_type is None:
  506. if self.num_labels == 1:
  507. self.config.problem_type = "regression"
  508. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  509. self.config.problem_type = "single_label_classification"
  510. else:
  511. self.config.problem_type = "multi_label_classification"
  512. if self.config.problem_type == "regression":
  513. loss_fct = MSELoss()
  514. if self.num_labels == 1:
  515. loss = loss_fct(logits.squeeze(), labels.squeeze())
  516. else:
  517. loss = loss_fct(logits, labels)
  518. elif self.config.problem_type == "single_label_classification":
  519. loss_fct = CrossEntropyLoss()
  520. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  521. elif self.config.problem_type == "multi_label_classification":
  522. loss_fct = BCEWithLogitsLoss()
  523. loss = loss_fct(logits, labels)
  524. if not return_dict:
  525. output = (logits,) + outputs[2:]
  526. return ((loss,) + output) if loss is not None else output
  527. return SequenceClassifierOutput(
  528. loss=loss,
  529. logits=logits,
  530. hidden_states=outputs.hidden_states,
  531. attentions=outputs.attentions,
  532. )
  533. @auto_docstring
  534. class SqueezeBertForMultipleChoice(SqueezeBertPreTrainedModel):
  535. def __init__(self, config):
  536. super().__init__(config)
  537. self.transformer = SqueezeBertModel(config)
  538. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  539. self.classifier = nn.Linear(config.hidden_size, 1)
  540. # Initialize weights and apply final processing
  541. self.post_init()
  542. @auto_docstring
  543. def forward(
  544. self,
  545. input_ids: torch.Tensor | None = None,
  546. attention_mask: torch.Tensor | None = None,
  547. token_type_ids: torch.Tensor | None = None,
  548. position_ids: torch.Tensor | None = None,
  549. inputs_embeds: torch.Tensor | None = None,
  550. labels: torch.Tensor | None = None,
  551. output_attentions: bool | None = None,
  552. output_hidden_states: bool | None = None,
  553. return_dict: bool | None = None,
  554. **kwargs,
  555. ) -> tuple | MultipleChoiceModelOutput:
  556. r"""
  557. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  558. Indices of input sequence tokens in the vocabulary.
  559. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  560. [`PreTrainedTokenizer.__call__`] for details.
  561. [What are input IDs?](../glossary#input-ids)
  562. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  563. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  564. 1]`:
  565. - 0 corresponds to a *sentence A* token,
  566. - 1 corresponds to a *sentence B* token.
  567. [What are token type IDs?](../glossary#token-type-ids)
  568. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  569. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  570. config.max_position_embeddings - 1]`.
  571. [What are position IDs?](../glossary#position-ids)
  572. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  573. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  574. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  575. model's internal embedding lookup matrix.
  576. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  577. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  578. num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
  579. *input_ids* above)
  580. """
  581. return_dict = return_dict if return_dict is not None else self.config.return_dict
  582. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  583. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  584. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  585. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  586. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  587. inputs_embeds = (
  588. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  589. if inputs_embeds is not None
  590. else None
  591. )
  592. outputs = self.transformer(
  593. input_ids,
  594. attention_mask=attention_mask,
  595. token_type_ids=token_type_ids,
  596. position_ids=position_ids,
  597. inputs_embeds=inputs_embeds,
  598. output_attentions=output_attentions,
  599. output_hidden_states=output_hidden_states,
  600. return_dict=return_dict,
  601. )
  602. pooled_output = outputs[1]
  603. pooled_output = self.dropout(pooled_output)
  604. logits = self.classifier(pooled_output)
  605. reshaped_logits = logits.view(-1, num_choices)
  606. loss = None
  607. if labels is not None:
  608. loss_fct = CrossEntropyLoss()
  609. loss = loss_fct(reshaped_logits, labels)
  610. if not return_dict:
  611. output = (reshaped_logits,) + outputs[2:]
  612. return ((loss,) + output) if loss is not None else output
  613. return MultipleChoiceModelOutput(
  614. loss=loss,
  615. logits=reshaped_logits,
  616. hidden_states=outputs.hidden_states,
  617. attentions=outputs.attentions,
  618. )
  619. @auto_docstring
  620. class SqueezeBertForTokenClassification(SqueezeBertPreTrainedModel):
  621. def __init__(self, config):
  622. super().__init__(config)
  623. self.num_labels = config.num_labels
  624. self.transformer = SqueezeBertModel(config)
  625. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  626. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  627. # Initialize weights and apply final processing
  628. self.post_init()
  629. @auto_docstring
  630. def forward(
  631. self,
  632. input_ids: torch.Tensor | None = None,
  633. attention_mask: torch.Tensor | None = None,
  634. token_type_ids: torch.Tensor | None = None,
  635. position_ids: torch.Tensor | None = None,
  636. inputs_embeds: torch.Tensor | None = None,
  637. labels: torch.Tensor | None = None,
  638. output_attentions: bool | None = None,
  639. output_hidden_states: bool | None = None,
  640. return_dict: bool | None = None,
  641. **kwargs,
  642. ) -> tuple | TokenClassifierOutput:
  643. r"""
  644. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  645. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  646. """
  647. return_dict = return_dict if return_dict is not None else self.config.return_dict
  648. outputs = self.transformer(
  649. input_ids,
  650. attention_mask=attention_mask,
  651. token_type_ids=token_type_ids,
  652. position_ids=position_ids,
  653. inputs_embeds=inputs_embeds,
  654. output_attentions=output_attentions,
  655. output_hidden_states=output_hidden_states,
  656. return_dict=return_dict,
  657. )
  658. sequence_output = outputs[0]
  659. sequence_output = self.dropout(sequence_output)
  660. logits = self.classifier(sequence_output)
  661. loss = None
  662. if labels is not None:
  663. loss_fct = CrossEntropyLoss()
  664. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  665. if not return_dict:
  666. output = (logits,) + outputs[2:]
  667. return ((loss,) + output) if loss is not None else output
  668. return TokenClassifierOutput(
  669. loss=loss,
  670. logits=logits,
  671. hidden_states=outputs.hidden_states,
  672. attentions=outputs.attentions,
  673. )
  674. @auto_docstring
  675. class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel):
  676. def __init__(self, config):
  677. super().__init__(config)
  678. self.num_labels = config.num_labels
  679. self.transformer = SqueezeBertModel(config)
  680. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  681. # Initialize weights and apply final processing
  682. self.post_init()
  683. @auto_docstring
  684. def forward(
  685. self,
  686. input_ids: torch.Tensor | None = None,
  687. attention_mask: torch.Tensor | None = None,
  688. token_type_ids: torch.Tensor | None = None,
  689. position_ids: torch.Tensor | None = None,
  690. inputs_embeds: torch.Tensor | None = None,
  691. start_positions: torch.Tensor | None = None,
  692. end_positions: torch.Tensor | None = None,
  693. output_attentions: bool | None = None,
  694. output_hidden_states: bool | None = None,
  695. return_dict: bool | None = None,
  696. **kwargs,
  697. ) -> tuple | QuestionAnsweringModelOutput:
  698. return_dict = return_dict if return_dict is not None else self.config.return_dict
  699. outputs = self.transformer(
  700. input_ids,
  701. attention_mask=attention_mask,
  702. token_type_ids=token_type_ids,
  703. position_ids=position_ids,
  704. inputs_embeds=inputs_embeds,
  705. output_attentions=output_attentions,
  706. output_hidden_states=output_hidden_states,
  707. return_dict=return_dict,
  708. )
  709. sequence_output = outputs[0]
  710. logits = self.qa_outputs(sequence_output)
  711. start_logits, end_logits = logits.split(1, dim=-1)
  712. start_logits = start_logits.squeeze(-1).contiguous()
  713. end_logits = end_logits.squeeze(-1).contiguous()
  714. total_loss = None
  715. if start_positions is not None and end_positions is not None:
  716. # If we are on multi-GPU, split add a dimension
  717. if len(start_positions.size()) > 1:
  718. start_positions = start_positions.squeeze(-1)
  719. if len(end_positions.size()) > 1:
  720. end_positions = end_positions.squeeze(-1)
  721. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  722. ignored_index = start_logits.size(1)
  723. start_positions = start_positions.clamp(0, ignored_index)
  724. end_positions = end_positions.clamp(0, ignored_index)
  725. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  726. start_loss = loss_fct(start_logits, start_positions)
  727. end_loss = loss_fct(end_logits, end_positions)
  728. total_loss = (start_loss + end_loss) / 2
  729. if not return_dict:
  730. output = (start_logits, end_logits) + outputs[2:]
  731. return ((total_loss,) + output) if total_loss is not None else output
  732. return QuestionAnsweringModelOutput(
  733. loss=total_loss,
  734. start_logits=start_logits,
  735. end_logits=end_logits,
  736. hidden_states=outputs.hidden_states,
  737. attentions=outputs.attentions,
  738. )
  739. __all__ = [
  740. "SqueezeBertForMaskedLM",
  741. "SqueezeBertForMultipleChoice",
  742. "SqueezeBertForQuestionAnswering",
  743. "SqueezeBertForSequenceClassification",
  744. "SqueezeBertForTokenClassification",
  745. "SqueezeBertModel",
  746. "SqueezeBertModule",
  747. "SqueezeBertPreTrainedModel",
  748. ]