modeling_ibert.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198
  1. # Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
  2. # Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
  3. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch I-BERT model."""
  17. import math
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ... import initialization as init
  22. from ...activations import gelu
  23. from ...modeling_outputs import (
  24. BaseModelOutputWithPastAndCrossAttentions,
  25. BaseModelOutputWithPoolingAndCrossAttentions,
  26. MaskedLMOutput,
  27. MultipleChoiceModelOutput,
  28. QuestionAnsweringModelOutput,
  29. SequenceClassifierOutput,
  30. TokenClassifierOutput,
  31. )
  32. from ...modeling_utils import PreTrainedModel
  33. from ...utils import auto_docstring, logging
  34. from .configuration_ibert import IBertConfig
  35. from .quant_modules import IntGELU, IntLayerNorm, IntSoftmax, QuantAct, QuantEmbedding, QuantLinear
  36. logger = logging.get_logger(__name__)
  37. class IBertEmbeddings(nn.Module):
  38. """
  39. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  40. """
  41. def __init__(self, config):
  42. super().__init__()
  43. self.quant_mode = config.quant_mode
  44. self.embedding_bit = 8
  45. self.embedding_act_bit = 16
  46. self.act_bit = 8
  47. self.ln_input_bit = 22
  48. self.ln_output_bit = 32
  49. self.word_embeddings = QuantEmbedding(
  50. config.vocab_size,
  51. config.hidden_size,
  52. padding_idx=config.pad_token_id,
  53. weight_bit=self.embedding_bit,
  54. quant_mode=self.quant_mode,
  55. )
  56. self.token_type_embeddings = QuantEmbedding(
  57. config.type_vocab_size, config.hidden_size, weight_bit=self.embedding_bit, quant_mode=self.quant_mode
  58. )
  59. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  60. self.register_buffer(
  61. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  62. )
  63. # End copy
  64. self.padding_idx = config.pad_token_id
  65. self.position_embeddings = QuantEmbedding(
  66. config.max_position_embeddings,
  67. config.hidden_size,
  68. padding_idx=self.padding_idx,
  69. weight_bit=self.embedding_bit,
  70. quant_mode=self.quant_mode,
  71. )
  72. # Integer-only addition between embeddings
  73. self.embeddings_act1 = QuantAct(self.embedding_act_bit, quant_mode=self.quant_mode)
  74. self.embeddings_act2 = QuantAct(self.embedding_act_bit, quant_mode=self.quant_mode)
  75. self.LayerNorm = IntLayerNorm(
  76. config.hidden_size,
  77. eps=config.layer_norm_eps,
  78. output_bit=self.ln_output_bit,
  79. quant_mode=self.quant_mode,
  80. force_dequant=config.force_dequant,
  81. )
  82. self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  83. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  84. def forward(
  85. self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
  86. ):
  87. if position_ids is None:
  88. if input_ids is not None:
  89. # Create the position ids from the input token ids. Any padded tokens remain padded.
  90. position_ids = create_position_ids_from_input_ids(
  91. input_ids, self.padding_idx, past_key_values_length
  92. ).to(input_ids.device)
  93. else:
  94. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  95. if input_ids is not None:
  96. input_shape = input_ids.size()
  97. else:
  98. input_shape = inputs_embeds.size()[:-1]
  99. if token_type_ids is None:
  100. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  101. if inputs_embeds is None:
  102. inputs_embeds, inputs_embeds_scaling_factor = self.word_embeddings(input_ids)
  103. else:
  104. inputs_embeds_scaling_factor = None
  105. token_type_embeddings, token_type_embeddings_scaling_factor = self.token_type_embeddings(token_type_ids)
  106. embeddings, embeddings_scaling_factor = self.embeddings_act1(
  107. inputs_embeds,
  108. inputs_embeds_scaling_factor,
  109. identity=token_type_embeddings,
  110. identity_scaling_factor=token_type_embeddings_scaling_factor,
  111. )
  112. position_embeddings, position_embeddings_scaling_factor = self.position_embeddings(position_ids)
  113. embeddings, embeddings_scaling_factor = self.embeddings_act1(
  114. embeddings,
  115. embeddings_scaling_factor,
  116. identity=position_embeddings,
  117. identity_scaling_factor=position_embeddings_scaling_factor,
  118. )
  119. embeddings, embeddings_scaling_factor = self.LayerNorm(embeddings, embeddings_scaling_factor)
  120. embeddings = self.dropout(embeddings)
  121. embeddings, embeddings_scaling_factor = self.output_activation(embeddings, embeddings_scaling_factor)
  122. return embeddings, embeddings_scaling_factor
  123. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  124. """
  125. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  126. Args:
  127. inputs_embeds: torch.Tensor
  128. Returns: torch.Tensor
  129. """
  130. input_shape = inputs_embeds.size()[:-1]
  131. sequence_length = input_shape[1]
  132. position_ids = torch.arange(
  133. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  134. )
  135. return position_ids.unsqueeze(0).expand(input_shape)
  136. class IBertSelfAttention(nn.Module):
  137. def __init__(self, config):
  138. super().__init__()
  139. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  140. raise ValueError(
  141. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  142. f"heads ({config.num_attention_heads})"
  143. )
  144. self.quant_mode = config.quant_mode
  145. self.weight_bit = 8
  146. self.bias_bit = 32
  147. self.act_bit = 8
  148. self.num_attention_heads = config.num_attention_heads
  149. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  150. self.all_head_size = self.num_attention_heads * self.attention_head_size
  151. # Q, K, V Linear layers
  152. self.query = QuantLinear(
  153. config.hidden_size,
  154. self.all_head_size,
  155. bias=True,
  156. weight_bit=self.weight_bit,
  157. bias_bit=self.bias_bit,
  158. quant_mode=self.quant_mode,
  159. per_channel=True,
  160. )
  161. self.key = QuantLinear(
  162. config.hidden_size,
  163. self.all_head_size,
  164. bias=True,
  165. weight_bit=self.weight_bit,
  166. bias_bit=self.bias_bit,
  167. quant_mode=self.quant_mode,
  168. per_channel=True,
  169. )
  170. self.value = QuantLinear(
  171. config.hidden_size,
  172. self.all_head_size,
  173. bias=True,
  174. weight_bit=self.weight_bit,
  175. bias_bit=self.bias_bit,
  176. quant_mode=self.quant_mode,
  177. per_channel=True,
  178. )
  179. # Requantization (32bit -> 8bit) for Q, K, V activations
  180. self.query_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  181. self.key_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  182. self.value_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  183. self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  184. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  185. self.softmax = IntSoftmax(self.act_bit, quant_mode=self.quant_mode, force_dequant=config.force_dequant)
  186. def forward(
  187. self,
  188. hidden_states,
  189. hidden_states_scaling_factor,
  190. attention_mask=None,
  191. output_attentions=False,
  192. ):
  193. # Projection
  194. mixed_query_layer, mixed_query_layer_scaling_factor = self.query(hidden_states, hidden_states_scaling_factor)
  195. mixed_key_layer, mixed_key_layer_scaling_factor = self.key(hidden_states, hidden_states_scaling_factor)
  196. mixed_value_layer, mixed_value_layer_scaling_factor = self.value(hidden_states, hidden_states_scaling_factor)
  197. # Requantization
  198. query_layer, query_layer_scaling_factor = self.query_activation(
  199. mixed_query_layer, mixed_query_layer_scaling_factor
  200. )
  201. key_layer, key_layer_scaling_factor = self.key_activation(mixed_key_layer, mixed_key_layer_scaling_factor)
  202. value_layer, value_layer_scaling_factor = self.value_activation(
  203. mixed_value_layer, mixed_value_layer_scaling_factor
  204. )
  205. # Transpose
  206. input_shape = hidden_states.shape[:-1]
  207. hidden_shape = (*input_shape, -1, self.attention_head_size)
  208. query_layer = query_layer.view(hidden_shape).transpose(1, 2)
  209. key_layer = key_layer.view(hidden_shape).transpose(1, 2)
  210. value_layer = value_layer.view(hidden_shape).transpose(1, 2)
  211. # Take the dot product between "query" and "key" to get the raw attention scores.
  212. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  213. scale = math.sqrt(self.attention_head_size)
  214. attention_scores = attention_scores / scale
  215. if self.quant_mode:
  216. attention_scores_scaling_factor = query_layer_scaling_factor * key_layer_scaling_factor / scale
  217. else:
  218. attention_scores_scaling_factor = None
  219. if attention_mask is not None:
  220. # Apply the attention mask is (precomputed for all layers in IBertModel forward() function)
  221. attention_scores = attention_scores + attention_mask
  222. # Normalize the attention scores to probabilities.
  223. attention_probs, attention_probs_scaling_factor = self.softmax(
  224. attention_scores, attention_scores_scaling_factor
  225. )
  226. # This is actually dropping out entire tokens to attend to, which might
  227. # seem a bit unusual, but is taken from the original Transformer paper.
  228. attention_probs = self.dropout(attention_probs)
  229. context_layer = torch.matmul(attention_probs, value_layer)
  230. if attention_probs_scaling_factor is not None:
  231. context_layer_scaling_factor = attention_probs_scaling_factor * value_layer_scaling_factor
  232. else:
  233. context_layer_scaling_factor = None
  234. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  235. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  236. context_layer = context_layer.view(*new_context_layer_shape)
  237. # requantization: 32-bit -> 8-bit
  238. context_layer, context_layer_scaling_factor = self.output_activation(
  239. context_layer, context_layer_scaling_factor
  240. )
  241. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  242. output_scaling_factor = (
  243. (context_layer_scaling_factor, attention_probs_scaling_factor)
  244. if output_attentions
  245. else (context_layer_scaling_factor,)
  246. )
  247. return outputs, output_scaling_factor
  248. class IBertSelfOutput(nn.Module):
  249. def __init__(self, config):
  250. super().__init__()
  251. self.quant_mode = config.quant_mode
  252. self.act_bit = 8
  253. self.weight_bit = 8
  254. self.bias_bit = 32
  255. self.ln_input_bit = 22
  256. self.ln_output_bit = 32
  257. self.dense = QuantLinear(
  258. config.hidden_size,
  259. config.hidden_size,
  260. bias=True,
  261. weight_bit=self.weight_bit,
  262. bias_bit=self.bias_bit,
  263. quant_mode=self.quant_mode,
  264. per_channel=True,
  265. )
  266. self.ln_input_act = QuantAct(self.ln_input_bit, quant_mode=self.quant_mode)
  267. self.LayerNorm = IntLayerNorm(
  268. config.hidden_size,
  269. eps=config.layer_norm_eps,
  270. output_bit=self.ln_output_bit,
  271. quant_mode=self.quant_mode,
  272. force_dequant=config.force_dequant,
  273. )
  274. self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  275. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  276. def forward(self, hidden_states, hidden_states_scaling_factor, input_tensor, input_tensor_scaling_factor):
  277. hidden_states, hidden_states_scaling_factor = self.dense(hidden_states, hidden_states_scaling_factor)
  278. hidden_states = self.dropout(hidden_states)
  279. hidden_states, hidden_states_scaling_factor = self.ln_input_act(
  280. hidden_states,
  281. hidden_states_scaling_factor,
  282. identity=input_tensor,
  283. identity_scaling_factor=input_tensor_scaling_factor,
  284. )
  285. hidden_states, hidden_states_scaling_factor = self.LayerNorm(hidden_states, hidden_states_scaling_factor)
  286. hidden_states, hidden_states_scaling_factor = self.output_activation(
  287. hidden_states, hidden_states_scaling_factor
  288. )
  289. return hidden_states, hidden_states_scaling_factor
  290. class IBertAttention(nn.Module):
  291. def __init__(self, config):
  292. super().__init__()
  293. self.quant_mode = config.quant_mode
  294. self.self = IBertSelfAttention(config)
  295. self.output = IBertSelfOutput(config)
  296. def forward(
  297. self,
  298. hidden_states,
  299. hidden_states_scaling_factor,
  300. attention_mask=None,
  301. output_attentions=False,
  302. ):
  303. self_outputs, self_outputs_scaling_factor = self.self(
  304. hidden_states,
  305. hidden_states_scaling_factor,
  306. attention_mask,
  307. output_attentions,
  308. )
  309. attention_output, attention_output_scaling_factor = self.output(
  310. self_outputs[0], self_outputs_scaling_factor[0], hidden_states, hidden_states_scaling_factor
  311. )
  312. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  313. outputs_scaling_factor = (attention_output_scaling_factor,) + self_outputs_scaling_factor[1:]
  314. return outputs, outputs_scaling_factor
  315. class IBertIntermediate(nn.Module):
  316. def __init__(self, config):
  317. super().__init__()
  318. self.quant_mode = config.quant_mode
  319. self.act_bit = 8
  320. self.weight_bit = 8
  321. self.bias_bit = 32
  322. self.dense = QuantLinear(
  323. config.hidden_size,
  324. config.intermediate_size,
  325. bias=True,
  326. weight_bit=self.weight_bit,
  327. bias_bit=self.bias_bit,
  328. quant_mode=self.quant_mode,
  329. per_channel=True,
  330. )
  331. if config.hidden_act != "gelu":
  332. raise ValueError("I-BERT only supports 'gelu' for `config.hidden_act`")
  333. self.intermediate_act_fn = IntGELU(quant_mode=self.quant_mode, force_dequant=config.force_dequant)
  334. self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  335. def forward(self, hidden_states, hidden_states_scaling_factor):
  336. hidden_states, hidden_states_scaling_factor = self.dense(hidden_states, hidden_states_scaling_factor)
  337. hidden_states, hidden_states_scaling_factor = self.intermediate_act_fn(
  338. hidden_states, hidden_states_scaling_factor
  339. )
  340. # Requantization: 32bit -> 8-bit
  341. hidden_states, hidden_states_scaling_factor = self.output_activation(
  342. hidden_states, hidden_states_scaling_factor
  343. )
  344. return hidden_states, hidden_states_scaling_factor
  345. class IBertOutput(nn.Module):
  346. def __init__(self, config):
  347. super().__init__()
  348. self.quant_mode = config.quant_mode
  349. self.act_bit = 8
  350. self.weight_bit = 8
  351. self.bias_bit = 32
  352. self.ln_input_bit = 22
  353. self.ln_output_bit = 32
  354. self.dense = QuantLinear(
  355. config.intermediate_size,
  356. config.hidden_size,
  357. bias=True,
  358. weight_bit=self.weight_bit,
  359. bias_bit=self.bias_bit,
  360. quant_mode=self.quant_mode,
  361. per_channel=True,
  362. )
  363. self.ln_input_act = QuantAct(self.ln_input_bit, quant_mode=self.quant_mode)
  364. self.LayerNorm = IntLayerNorm(
  365. config.hidden_size,
  366. eps=config.layer_norm_eps,
  367. output_bit=self.ln_output_bit,
  368. quant_mode=self.quant_mode,
  369. force_dequant=config.force_dequant,
  370. )
  371. self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  372. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  373. def forward(self, hidden_states, hidden_states_scaling_factor, input_tensor, input_tensor_scaling_factor):
  374. hidden_states, hidden_states_scaling_factor = self.dense(hidden_states, hidden_states_scaling_factor)
  375. hidden_states = self.dropout(hidden_states)
  376. hidden_states, hidden_states_scaling_factor = self.ln_input_act(
  377. hidden_states,
  378. hidden_states_scaling_factor,
  379. identity=input_tensor,
  380. identity_scaling_factor=input_tensor_scaling_factor,
  381. )
  382. hidden_states, hidden_states_scaling_factor = self.LayerNorm(hidden_states, hidden_states_scaling_factor)
  383. hidden_states, hidden_states_scaling_factor = self.output_activation(
  384. hidden_states, hidden_states_scaling_factor
  385. )
  386. return hidden_states, hidden_states_scaling_factor
  387. class IBertLayer(nn.Module):
  388. def __init__(self, config):
  389. super().__init__()
  390. self.quant_mode = config.quant_mode
  391. self.act_bit = 8
  392. self.seq_len_dim = 1
  393. self.attention = IBertAttention(config)
  394. self.intermediate = IBertIntermediate(config)
  395. self.output = IBertOutput(config)
  396. self.pre_intermediate_act = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  397. self.pre_output_act = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  398. def forward(
  399. self,
  400. hidden_states,
  401. hidden_states_scaling_factor,
  402. attention_mask=None,
  403. output_attentions=False,
  404. ):
  405. self_attention_outputs, self_attention_outputs_scaling_factor = self.attention(
  406. hidden_states,
  407. hidden_states_scaling_factor,
  408. attention_mask,
  409. output_attentions=output_attentions,
  410. )
  411. attention_output = self_attention_outputs[0]
  412. attention_output_scaling_factor = self_attention_outputs_scaling_factor[0]
  413. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  414. layer_output, layer_output_scaling_factor = self.feed_forward_chunk(
  415. attention_output, attention_output_scaling_factor
  416. )
  417. outputs = (layer_output,) + outputs
  418. return outputs
  419. def feed_forward_chunk(self, attention_output, attention_output_scaling_factor):
  420. attention_output, attention_output_scaling_factor = self.pre_intermediate_act(
  421. attention_output, attention_output_scaling_factor
  422. )
  423. intermediate_output, intermediate_output_scaling_factor = self.intermediate(
  424. attention_output, attention_output_scaling_factor
  425. )
  426. intermediate_output, intermediate_output_scaling_factor = self.pre_output_act(
  427. intermediate_output, intermediate_output_scaling_factor
  428. )
  429. layer_output, layer_output_scaling_factor = self.output(
  430. intermediate_output, intermediate_output_scaling_factor, attention_output, attention_output_scaling_factor
  431. )
  432. return layer_output, layer_output_scaling_factor
  433. class IBertEncoder(nn.Module):
  434. def __init__(self, config):
  435. super().__init__()
  436. self.config = config
  437. self.quant_mode = config.quant_mode
  438. self.layer = nn.ModuleList([IBertLayer(config) for _ in range(config.num_hidden_layers)])
  439. def forward(
  440. self,
  441. hidden_states,
  442. hidden_states_scaling_factor,
  443. attention_mask=None,
  444. output_attentions=False,
  445. output_hidden_states=False,
  446. return_dict=True,
  447. ):
  448. all_hidden_states = () if output_hidden_states else None
  449. all_self_attentions = () if output_attentions else None
  450. all_cross_attentions = None # `config.add_cross_attention` is not supported
  451. for i, layer_module in enumerate(self.layer):
  452. if output_hidden_states:
  453. all_hidden_states = all_hidden_states + (hidden_states,)
  454. layer_outputs = layer_module(
  455. hidden_states,
  456. hidden_states_scaling_factor,
  457. attention_mask,
  458. output_attentions,
  459. )
  460. hidden_states = layer_outputs[0]
  461. if output_attentions:
  462. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  463. if output_hidden_states:
  464. all_hidden_states = all_hidden_states + (hidden_states,)
  465. if not return_dict:
  466. return tuple(
  467. v
  468. for v in [
  469. hidden_states,
  470. all_hidden_states,
  471. all_self_attentions,
  472. all_cross_attentions,
  473. ]
  474. if v is not None
  475. )
  476. return BaseModelOutputWithPastAndCrossAttentions(
  477. last_hidden_state=hidden_states,
  478. hidden_states=all_hidden_states,
  479. attentions=all_self_attentions,
  480. cross_attentions=all_cross_attentions,
  481. )
  482. class IBertPooler(nn.Module):
  483. def __init__(self, config):
  484. super().__init__()
  485. self.quant_mode = config.quant_mode
  486. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  487. self.activation = nn.Tanh()
  488. def forward(self, hidden_states):
  489. # We "pool" the model by simply taking the hidden state corresponding
  490. # to the first token.
  491. first_token_tensor = hidden_states[:, 0]
  492. pooled_output = self.dense(first_token_tensor)
  493. pooled_output = self.activation(pooled_output)
  494. return pooled_output
  495. @auto_docstring
  496. class IBertPreTrainedModel(PreTrainedModel):
  497. config: IBertConfig
  498. base_model_prefix = "ibert"
  499. @torch.no_grad()
  500. def _init_weights(self, module):
  501. """Initialize the weights"""
  502. if isinstance(module, (QuantLinear, nn.Linear)):
  503. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  504. if module.bias is not None:
  505. init.zeros_(module.bias)
  506. if getattr(module, "weight_integer", None) is not None:
  507. init.zeros_(module.weight_integer)
  508. init.zeros_(module.fc_scaling_factor)
  509. if getattr(module, "bias_integer", None) is not None:
  510. init.zeros_(module.bias_integer)
  511. elif isinstance(module, (QuantEmbedding, nn.Embedding)):
  512. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  513. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  514. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  515. init.zeros_(module.weight[module.padding_idx])
  516. if getattr(module, "weight_scaling_factor", None) is not None:
  517. init.zeros_(module.weight_scaling_factor)
  518. init.zeros_(module.weight_integer)
  519. elif isinstance(module, (IntLayerNorm, nn.LayerNorm)):
  520. init.zeros_(module.bias)
  521. init.ones_(module.weight)
  522. if getattr(module, "shift", None) is not None:
  523. init.zeros_(module.shift)
  524. elif isinstance(module, IBertLMHead):
  525. init.zeros_(module.bias)
  526. elif isinstance(module, IBertEmbeddings):
  527. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  528. elif isinstance(module, QuantAct):
  529. init.constant_(module.x_min, -1e-5)
  530. init.constant_(module.x_max, 1e-5)
  531. init.zeros_(module.act_scaling_factor)
  532. def resize_token_embeddings(self, new_num_tokens=None):
  533. raise NotImplementedError("`resize_token_embeddings` is not supported for I-BERT.")
  534. @auto_docstring
  535. class IBertModel(IBertPreTrainedModel):
  536. """
  537. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  538. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  539. all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  540. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  541. """
  542. def __init__(self, config, add_pooling_layer=True):
  543. r"""
  544. add_pooling_layer (bool, *optional*, defaults to `True`):
  545. Whether to add a pooling layer
  546. """
  547. super().__init__(config)
  548. self.config = config
  549. self.quant_mode = config.quant_mode
  550. self.embeddings = IBertEmbeddings(config)
  551. self.encoder = IBertEncoder(config)
  552. self.pooler = IBertPooler(config) if add_pooling_layer else None
  553. # Initialize weights and apply final processing
  554. self.post_init()
  555. def get_input_embeddings(self):
  556. return self.embeddings.word_embeddings
  557. def set_input_embeddings(self, value):
  558. self.embeddings.word_embeddings = value
  559. @auto_docstring
  560. def forward(
  561. self,
  562. input_ids: torch.LongTensor | None = None,
  563. attention_mask: torch.FloatTensor | None = None,
  564. token_type_ids: torch.LongTensor | None = None,
  565. position_ids: torch.LongTensor | None = None,
  566. inputs_embeds: torch.FloatTensor | None = None,
  567. output_attentions: bool | None = None,
  568. output_hidden_states: bool | None = None,
  569. return_dict: bool | None = None,
  570. **kwargs,
  571. ) -> BaseModelOutputWithPoolingAndCrossAttentions | tuple[torch.FloatTensor]:
  572. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  573. output_hidden_states = (
  574. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  575. )
  576. return_dict = return_dict if return_dict is not None else self.config.return_dict
  577. if input_ids is not None and inputs_embeds is not None:
  578. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  579. elif input_ids is not None:
  580. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  581. input_shape = input_ids.size()
  582. elif inputs_embeds is not None:
  583. input_shape = inputs_embeds.size()[:-1]
  584. else:
  585. raise ValueError("You have to specify either input_ids or inputs_embeds")
  586. batch_size, seq_length = input_shape
  587. device = input_ids.device if input_ids is not None else inputs_embeds.device
  588. if attention_mask is None:
  589. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  590. if token_type_ids is None:
  591. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  592. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  593. # ourselves in which case we just need to make it broadcastable to all heads.
  594. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  595. embedding_output, embedding_output_scaling_factor = self.embeddings(
  596. input_ids=input_ids,
  597. position_ids=position_ids,
  598. token_type_ids=token_type_ids,
  599. inputs_embeds=inputs_embeds,
  600. )
  601. encoder_outputs = self.encoder(
  602. embedding_output,
  603. embedding_output_scaling_factor,
  604. attention_mask=extended_attention_mask,
  605. output_attentions=output_attentions,
  606. output_hidden_states=output_hidden_states,
  607. return_dict=return_dict,
  608. )
  609. sequence_output = encoder_outputs[0]
  610. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  611. if not return_dict:
  612. return (sequence_output, pooled_output) + encoder_outputs[1:]
  613. return BaseModelOutputWithPoolingAndCrossAttentions(
  614. last_hidden_state=sequence_output,
  615. pooler_output=pooled_output,
  616. hidden_states=encoder_outputs.hidden_states,
  617. attentions=encoder_outputs.attentions,
  618. cross_attentions=encoder_outputs.cross_attentions,
  619. )
  @auto_docstring
  class IBertForMaskedLM(IBertPreTrainedModel):
      # Tied-parameter table: each key names a parameter of this model and each value
      # names the parameter it is paired with. How the table is consumed (including the
      # trailing `$` in the first value) is decided by the base class, outside this block.
      _tied_weights_keys = {
          "lm_head.decoder.weight": "ibert.embeddings.word_embeddings.weight$",
          "lm_head.decoder.bias": "lm_head.bias",
      }

      def __init__(self, config):
          super().__init__(config)

          # Encoder built without the pooling layer, followed by the vocabulary head.
          self.ibert = IBertModel(config, add_pooling_layer=False)
          self.lm_head = IBertLMHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          """Return the head's `decoder` layer, which emits the vocabulary logits."""
          return self.lm_head.decoder

      def set_output_embeddings(self, new_embeddings):
          """Install `new_embeddings` as the decoder and reuse its bias as the head bias."""
          head = self.lm_head
          head.decoder = new_embeddings
          head.bias = new_embeddings.bias

      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          token_type_ids: torch.LongTensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> MaskedLMOutput | tuple[torch.FloatTensor]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
              config.vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
          """
          # Fall back to the config default when the caller did not choose an output format.
          if return_dict is None:
              return_dict = self.config.return_dict

          encoder_outputs = self.ibert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          # Element 0 of the encoder result is the per-token hidden states.
          prediction_scores = self.lm_head(encoder_outputs[0])

          masked_lm_loss = None
          if labels is not None:
              # Flatten to (tokens, vocab) vs. (tokens,) so the loss sees one row per token.
              flat_scores = prediction_scores.view(-1, self.config.vocab_size)
              masked_lm_loss = CrossEntropyLoss()(flat_scores, labels.view(-1))

          if return_dict:
              return MaskedLMOutput(
                  loss=masked_lm_loss,
                  logits=prediction_scores,
                  hidden_states=encoder_outputs.hidden_states,
                  attentions=encoder_outputs.attentions,
              )

          # Tuple output: logits followed by everything after the first two encoder entries,
          # with the loss prepended only when it was computed.
          output = (prediction_scores,) + encoder_outputs[2:]
          if masked_lm_loss is None:
              return output
          return (masked_lm_loss,) + output
  class IBertLMHead(nn.Module):
      """Masked-language-modeling head for I-BERT.

      Runs the incoming features through a dense layer, GELU and layer normalization,
      then projects the result to `config.vocab_size` logits.
      """

      def __init__(self, config):
          super().__init__()
          hidden_size = config.hidden_size
          vocab_size = config.vocab_size

          self.dense = nn.Linear(hidden_size, hidden_size)
          self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
          self.decoder = nn.Linear(hidden_size, vocab_size)
          # Standalone vocabulary-sized bias; `forward` below never reads it directly.
          self.bias = nn.Parameter(torch.zeros(vocab_size))

      def forward(self, features, **kwargs):
          """Return vocabulary logits for `features`; extra keyword arguments are ignored."""
          transformed = self.layer_norm(gelu(self.dense(features)))
          # project back to size of vocabulary with bias
          return self.decoder(transformed)
  @auto_docstring(
      custom_intro="""
      I-BERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
      output) e.g. for GLUE tasks.
      """
  )
  class IBertForSequenceClassification(IBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels

          # Encoder built without the pooling layer; the head below reads the first token itself.
          self.ibert = IBertModel(config, add_pooling_layer=False)
          self.classifier = IBertClassificationHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          token_type_ids: torch.LongTensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> SequenceClassifierOutput | tuple[torch.FloatTensor]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
          """
          # Fall back to the config default when the caller did not choose an output format.
          if return_dict is None:
              return_dict = self.config.return_dict

          encoder_outputs = self.ibert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          logits = self.classifier(encoder_outputs[0])

          loss = None
          if labels is not None:
              # First labelled call with no configured problem type: infer it from the
              # number of labels and the label dtype, and store it on the config.
              if self.config.problem_type is None:
                  if self.num_labels == 1:
                      self.config.problem_type = "regression"
                  elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                      self.config.problem_type = "single_label_classification"
                  else:
                      self.config.problem_type = "multi_label_classification"

              problem_type = self.config.problem_type
              if problem_type == "regression":
                  criterion = MSELoss()
                  if self.num_labels == 1:
                      # Single target: drop size-1 dimensions on both sides before comparing.
                      loss = criterion(logits.squeeze(), labels.squeeze())
                  else:
                      loss = criterion(logits, labels)
              elif problem_type == "single_label_classification":
                  criterion = CrossEntropyLoss()
                  loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
              elif problem_type == "multi_label_classification":
                  criterion = BCEWithLogitsLoss()
                  loss = criterion(logits, labels)

          if return_dict:
              return SequenceClassifierOutput(
                  loss=loss,
                  logits=logits,
                  hidden_states=encoder_outputs.hidden_states,
                  attentions=encoder_outputs.attentions,
              )

          # Tuple output: logits plus the trailing encoder entries, loss first when present.
          output = (logits,) + encoder_outputs[2:]
          if loss is None:
              return output
          return (loss,) + output
  @auto_docstring
  class IBertForMultipleChoice(IBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          # Encoder with its default pooling setting, then dropout and a single-score classifier.
          self.ibert = IBertModel(config)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          self.classifier = nn.Linear(config.hidden_size, 1)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          token_type_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> MultipleChoiceModelOutput | tuple[torch.FloatTensor]:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
              Indices of input sequence tokens in the vocabulary.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are input IDs?](../glossary#input-ids)
          token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
              Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
              1]`:

              - 0 corresponds to a *sentence A* token,
              - 1 corresponds to a *sentence B* token.

              [What are token type IDs?](../glossary#token-type-ids)
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
              num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
              `input_ids` above)
          position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
              Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
              config.max_position_embeddings - 1]`.

              [What are position IDs?](../glossary#position-ids)
          inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
              Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
              is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
              model's internal embedding lookup matrix.
          """
          # Fall back to the config default when the caller did not choose an output format.
          if return_dict is None:
              return_dict = self.config.return_dict

          # The number of choices is dimension 1 of whichever input was supplied.
          num_choices = (input_ids if input_ids is not None else inputs_embeds).shape[1]

          # Merge the leading dimensions so each choice becomes its own row for the encoder.
          flat_input_ids = None if input_ids is None else input_ids.view(-1, input_ids.size(-1))
          flat_position_ids = None if position_ids is None else position_ids.view(-1, position_ids.size(-1))
          flat_token_type_ids = None if token_type_ids is None else token_type_ids.view(-1, token_type_ids.size(-1))
          flat_attention_mask = None if attention_mask is None else attention_mask.view(-1, attention_mask.size(-1))
          flat_inputs_embeds = None
          if inputs_embeds is not None:
              flat_inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

          encoder_outputs = self.ibert(
              flat_input_ids,
              position_ids=flat_position_ids,
              token_type_ids=flat_token_type_ids,
              attention_mask=flat_attention_mask,
              inputs_embeds=flat_inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # Element 1 of the encoder result is the pooled output; score it, then regroup
          # the per-row scores into one row of `num_choices` scores per example.
          pooled_output = self.dropout(encoder_outputs[1])
          reshaped_logits = self.classifier(pooled_output).view(-1, num_choices)

          loss = None
          if labels is not None:
              loss = CrossEntropyLoss()(reshaped_logits, labels)

          if return_dict:
              return MultipleChoiceModelOutput(
                  loss=loss,
                  logits=reshaped_logits,
                  hidden_states=encoder_outputs.hidden_states,
                  attentions=encoder_outputs.attentions,
              )

          # Tuple output: logits plus the trailing encoder entries, loss first when present.
          output = (reshaped_logits,) + encoder_outputs[2:]
          if loss is None:
              return output
          return (loss,) + output
  @auto_docstring
  class IBertForTokenClassification(IBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels

          # Encoder built without the pooling layer, then dropout and a per-token classifier.
          self.ibert = IBertModel(config, add_pooling_layer=False)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          self.classifier = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          token_type_ids: torch.LongTensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> TokenClassifierOutput | tuple[torch.FloatTensor]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
          """
          # Fall back to the config default when the caller did not choose an output format.
          if return_dict is None:
              return_dict = self.config.return_dict

          encoder_outputs = self.ibert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # Element 0 of the encoder result is the per-token hidden states.
          token_states = self.dropout(encoder_outputs[0])
          logits = self.classifier(token_states)

          loss = None
          if labels is not None:
              # Flatten to (tokens, num_labels) vs. (tokens,) so the loss sees one row per token.
              loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

          if return_dict:
              return TokenClassifierOutput(
                  loss=loss,
                  logits=logits,
                  hidden_states=encoder_outputs.hidden_states,
                  attentions=encoder_outputs.attentions,
              )

          # Tuple output: logits plus the trailing encoder entries, loss first when present.
          output = (logits,) + encoder_outputs[2:]
          if loss is None:
              return output
          return (loss,) + output
  class IBertClassificationHead(nn.Module):
      """Sentence-level classification head.

      Reads only the first position of the sequence and maps it to
      `config.num_labels` scores through dropout, a dense layer, tanh, dropout
      and an output projection.
      """

      def __init__(self, config):
          super().__init__()
          hidden_size = config.hidden_size

          self.dense = nn.Linear(hidden_size, hidden_size)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          self.out_proj = nn.Linear(hidden_size, config.num_labels)

      def forward(self, features, **kwargs):
          """Return the label scores for `features`; extra keyword arguments are ignored."""
          first_token = features[:, 0, :]  # take <s> token (equiv. to [CLS])
          projected = self.dense(self.dropout(first_token))
          activated = torch.tanh(projected)
          return self.out_proj(self.dropout(activated))
  @auto_docstring
  class IBertForQuestionAnswering(IBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels

          # Encoder built without the pooling layer, then a per-token projection whose
          # output is split below into start and end scores.
          self.ibert = IBertModel(config, add_pooling_layer=False)
          self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          token_type_ids: torch.LongTensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          start_positions: torch.LongTensor | None = None,
          end_positions: torch.LongTensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> QuestionAnsweringModelOutput | tuple[torch.FloatTensor]:
          # Fall back to the config default when the caller did not choose an output format.
          if return_dict is None:
              return_dict = self.config.return_dict

          encoder_outputs = self.ibert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # Project every token, then peel the last dimension into one slice per score type.
          logits = self.qa_outputs(encoder_outputs[0])
          start_logits, end_logits = (
              piece.squeeze(-1).contiguous() for piece in logits.split(1, dim=-1)
          )

          total_loss = None
          if start_positions is not None and end_positions is not None:
              # If we are on multi-GPU, split add a dimension
              if start_positions.dim() > 1:
                  start_positions = start_positions.squeeze(-1)
              if end_positions.dim() > 1:
                  end_positions = end_positions.squeeze(-1)

              # sometimes the start/end positions are outside our model inputs, we ignore these terms:
              # anything past the sequence is clamped onto `ignored_index`, which the loss skips.
              ignored_index = start_logits.size(1)
              start_positions = start_positions.clamp(0, ignored_index)
              end_positions = end_positions.clamp(0, ignored_index)

              criterion = CrossEntropyLoss(ignore_index=ignored_index)
              start_loss = criterion(start_logits, start_positions)
              end_loss = criterion(end_logits, end_positions)
              total_loss = (start_loss + end_loss) / 2

          if return_dict:
              return QuestionAnsweringModelOutput(
                  loss=total_loss,
                  start_logits=start_logits,
                  end_logits=end_logits,
                  hidden_states=encoder_outputs.hidden_states,
                  attentions=encoder_outputs.attentions,
              )

          # Tuple output: both logit tensors plus the trailing encoder entries, loss first when present.
          output = (start_logits, end_logits) + encoder_outputs[2:]
          if total_loss is None:
              return output
          return (total_loss,) + output
  def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
      """
      Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
      keep the value `padding_idx`. This is modified from fairseq's *utils.make_positions*.

      Args:
          input_ids (`torch.LongTensor`):
              Indices of input sequence tokens in the vocabulary. Positions are counted along dimension 1.
          padding_idx (`int`):
              Token id that marks padding; it is also the offset added to every position number.
          past_key_values_length (`int`, defaults to 0):
              Extra offset added to the position of every non-padding token.

      Returns: torch.Tensor
      """
      # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
      not_padding = input_ids.ne(padding_idx).int()
      running_count = torch.cumsum(not_padding, dim=1).type_as(not_padding)
      # Multiplying by the mask zeroes padding slots so they end up at exactly `padding_idx`.
      shifted = (running_count + past_key_values_length) * not_padding
      return shifted.long() + padding_idx
  # Names exported by this module.
  __all__ = [
      "IBertForMaskedLM",
      "IBertForMultipleChoice",
      "IBertForQuestionAnswering",
      "IBertForSequenceClassification",
      "IBertForTokenClassification",
      "IBertModel",
      "IBertPreTrainedModel",
  ]