modeling_convbert.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066
  1. # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ConvBERT model."""
  15. import math
  16. from collections.abc import Callable
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN, get_activation
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutputWithCrossAttentions,
  25. MaskedLMOutput,
  26. MultipleChoiceModelOutput,
  27. QuestionAnsweringModelOutput,
  28. SequenceClassifierOutput,
  29. TokenClassifierOutput,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...processing_utils import Unpack
  33. from ...pytorch_utils import apply_chunking_to_forward
  34. from ...utils import (
  35. TransformersKwargs,
  36. auto_docstring,
  37. can_return_tuple,
  38. logging,
  39. )
  40. from ...utils.generic import merge_with_config_defaults
  41. from ...utils.output_capturing import capture_outputs
  42. from .configuration_convbert import ConvBertConfig
  43. logger = logging.get_logger(__name__)
  44. class ConvBertEmbeddings(nn.Module):
  45. """Construct the embeddings from word, position and token_type embeddings."""
  46. def __init__(self, config):
  47. super().__init__()
  48. self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
  49. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
  50. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
  51. self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
  52. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  53. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  54. self.register_buffer(
  55. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  56. )
  57. self.register_buffer(
  58. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  59. )
  60. def forward(
  61. self,
  62. input_ids: torch.LongTensor | None = None,
  63. token_type_ids: torch.LongTensor | None = None,
  64. position_ids: torch.LongTensor | None = None,
  65. inputs_embeds: torch.FloatTensor | None = None,
  66. ) -> torch.LongTensor:
  67. if input_ids is not None:
  68. input_shape = input_ids.size()
  69. else:
  70. input_shape = inputs_embeds.size()[:-1]
  71. seq_length = input_shape[1]
  72. if position_ids is None:
  73. position_ids = self.position_ids[:, :seq_length]
  74. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  75. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  76. # issue #5664
  77. if token_type_ids is None:
  78. if hasattr(self, "token_type_ids"):
  79. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  80. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  81. token_type_ids = buffered_token_type_ids_expanded
  82. else:
  83. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  84. if inputs_embeds is None:
  85. inputs_embeds = self.word_embeddings(input_ids)
  86. position_embeddings = self.position_embeddings(position_ids)
  87. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  88. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  89. embeddings = self.LayerNorm(embeddings)
  90. embeddings = self.dropout(embeddings)
  91. return embeddings
  92. class SeparableConv1D(nn.Module):
  93. """This class implements separable convolution, i.e. a depthwise and a pointwise layer"""
  94. def __init__(self, config, input_filters, output_filters, kernel_size, **kwargs):
  95. super().__init__()
  96. self.depthwise = nn.Conv1d(
  97. input_filters,
  98. input_filters,
  99. kernel_size=kernel_size,
  100. groups=input_filters,
  101. padding=kernel_size // 2,
  102. bias=False,
  103. )
  104. self.pointwise = nn.Conv1d(input_filters, output_filters, kernel_size=1, bias=False)
  105. self.bias = nn.Parameter(torch.zeros(output_filters, 1))
  106. self.depthwise.weight.data.normal_(mean=0.0, std=config.initializer_range)
  107. self.pointwise.weight.data.normal_(mean=0.0, std=config.initializer_range)
  108. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  109. x = self.depthwise(hidden_states)
  110. x = self.pointwise(x)
  111. x += self.bias
  112. return x
  113. class ConvBertSelfAttention(nn.Module):
  114. def __init__(self, config):
  115. super().__init__()
  116. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  117. raise ValueError(
  118. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  119. f"heads ({config.num_attention_heads})"
  120. )
  121. new_num_attention_heads = config.num_attention_heads // config.head_ratio
  122. if new_num_attention_heads < 1:
  123. self.head_ratio = config.num_attention_heads
  124. self.num_attention_heads = 1
  125. else:
  126. self.num_attention_heads = new_num_attention_heads
  127. self.head_ratio = config.head_ratio
  128. self.conv_kernel_size = config.conv_kernel_size
  129. if config.hidden_size % self.num_attention_heads != 0:
  130. raise ValueError("hidden_size should be divisible by num_attention_heads")
  131. self.attention_head_size = (config.hidden_size // self.num_attention_heads) // 2
  132. self.all_head_size = self.num_attention_heads * self.attention_head_size
  133. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  134. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  135. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  136. self.key_conv_attn_layer = SeparableConv1D(
  137. config, config.hidden_size, self.all_head_size, self.conv_kernel_size
  138. )
  139. self.conv_kernel_layer = nn.Linear(self.all_head_size, self.num_attention_heads * self.conv_kernel_size)
  140. self.conv_out_layer = nn.Linear(config.hidden_size, self.all_head_size)
  141. self.unfold = nn.Unfold(
  142. kernel_size=[self.conv_kernel_size, 1], padding=[int((self.conv_kernel_size - 1) / 2), 0]
  143. )
  144. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  145. def forward(
  146. self,
  147. hidden_states: torch.Tensor,
  148. attention_mask: torch.FloatTensor | None = None,
  149. encoder_hidden_states: torch.Tensor | None = None,
  150. **kwargs: Unpack[TransformersKwargs],
  151. ) -> tuple[torch.Tensor, torch.Tensor]:
  152. input_shape = hidden_states.shape[:-1]
  153. hidden_shape = (*input_shape, -1, self.attention_head_size)
  154. # If this is instantiated as a cross-attention module, the keys
  155. # and values come from an encoder; the attention mask needs to be
  156. # such that the encoder's padding tokens are not attended to.
  157. if encoder_hidden_states is not None:
  158. mixed_key_layer = self.key(encoder_hidden_states)
  159. mixed_value_layer = self.value(encoder_hidden_states)
  160. else:
  161. mixed_key_layer = self.key(hidden_states)
  162. mixed_value_layer = self.value(hidden_states)
  163. mixed_key_conv_attn_layer = self.key_conv_attn_layer(hidden_states.transpose(1, 2))
  164. mixed_key_conv_attn_layer = mixed_key_conv_attn_layer.transpose(1, 2)
  165. mixed_query_layer = self.query(hidden_states)
  166. query_layer = mixed_query_layer.view(hidden_shape).transpose(1, 2)
  167. key_layer = mixed_key_layer.view(hidden_shape).transpose(1, 2)
  168. value_layer = mixed_value_layer.view(hidden_shape).transpose(1, 2)
  169. conv_attn_layer = torch.multiply(mixed_key_conv_attn_layer, mixed_query_layer)
  170. conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer)
  171. conv_kernel_layer = torch.reshape(conv_kernel_layer, [-1, self.conv_kernel_size, 1])
  172. conv_kernel_layer = torch.softmax(conv_kernel_layer, dim=1)
  173. conv_out_layer = self.conv_out_layer(hidden_states)
  174. conv_out_layer = torch.reshape(conv_out_layer, [input_shape[0], -1, self.all_head_size])
  175. conv_out_layer = conv_out_layer.transpose(1, 2).contiguous().unsqueeze(-1)
  176. conv_out_layer = nn.functional.unfold(
  177. conv_out_layer,
  178. kernel_size=[self.conv_kernel_size, 1],
  179. dilation=1,
  180. padding=[(self.conv_kernel_size - 1) // 2, 0],
  181. stride=1,
  182. )
  183. conv_out_layer = conv_out_layer.transpose(1, 2).reshape(
  184. input_shape[0], -1, self.all_head_size, self.conv_kernel_size
  185. )
  186. conv_out_layer = torch.reshape(conv_out_layer, [-1, self.attention_head_size, self.conv_kernel_size])
  187. conv_out_layer = torch.matmul(conv_out_layer, conv_kernel_layer)
  188. conv_out_layer = torch.reshape(conv_out_layer, [-1, self.all_head_size])
  189. # Take the dot product between "query" and "key" to get the raw attention scores.
  190. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  191. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  192. if attention_mask is not None:
  193. # Apply the attention mask is (precomputed for all layers in ConvBertModel forward() function)
  194. attention_scores = attention_scores + attention_mask
  195. # Normalize the attention scores to probabilities.
  196. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  197. # This is actually dropping out entire tokens to attend to, which might
  198. # seem a bit unusual, but is taken from the original Transformer paper.
  199. attention_probs = self.dropout(attention_probs)
  200. context_layer = torch.matmul(attention_probs, value_layer)
  201. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  202. conv_out = torch.reshape(
  203. conv_out_layer, [input_shape[0], -1, self.num_attention_heads, self.attention_head_size]
  204. )
  205. context_layer = torch.cat([context_layer, conv_out], 2)
  206. # conv and context
  207. new_context_layer_shape = context_layer.size()[:-2] + (
  208. self.num_attention_heads * self.attention_head_size * 2,
  209. )
  210. context_layer = context_layer.view(*new_context_layer_shape)
  211. return context_layer, attention_probs
  212. class ConvBertSelfOutput(nn.Module):
  213. def __init__(self, config):
  214. super().__init__()
  215. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  216. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  217. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  218. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  219. hidden_states = self.dense(hidden_states)
  220. hidden_states = self.dropout(hidden_states)
  221. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  222. return hidden_states
  223. class ConvBertAttention(nn.Module):
  224. def __init__(self, config):
  225. super().__init__()
  226. self.self = ConvBertSelfAttention(config)
  227. self.output = ConvBertSelfOutput(config)
  228. def forward(
  229. self,
  230. hidden_states: torch.Tensor,
  231. attention_mask: torch.FloatTensor | None = None,
  232. encoder_hidden_states: torch.Tensor | None = None,
  233. **kwargs: Unpack[TransformersKwargs],
  234. ) -> torch.Tensor:
  235. context_layer, _ = self.self(
  236. hidden_states,
  237. attention_mask,
  238. encoder_hidden_states=encoder_hidden_states,
  239. **kwargs,
  240. )
  241. attention_output = self.output(context_layer, hidden_states)
  242. return attention_output
  243. class GroupedLinearLayer(nn.Module):
  244. def __init__(self, input_size, output_size, num_groups):
  245. super().__init__()
  246. self.input_size = input_size
  247. self.output_size = output_size
  248. self.num_groups = num_groups
  249. self.group_in_dim = self.input_size // self.num_groups
  250. self.group_out_dim = self.output_size // self.num_groups
  251. self.weight = nn.Parameter(torch.empty(self.num_groups, self.group_in_dim, self.group_out_dim))
  252. self.bias = nn.Parameter(torch.empty(output_size))
  253. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  254. batch_size = list(hidden_states.size())[0]
  255. x = torch.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim])
  256. x = x.permute(1, 0, 2)
  257. x = torch.matmul(x, self.weight)
  258. x = x.permute(1, 0, 2)
  259. x = torch.reshape(x, [batch_size, -1, self.output_size])
  260. x = x + self.bias
  261. return x
  262. class ConvBertIntermediate(nn.Module):
  263. def __init__(self, config):
  264. super().__init__()
  265. if config.num_groups == 1:
  266. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  267. else:
  268. self.dense = GroupedLinearLayer(
  269. input_size=config.hidden_size, output_size=config.intermediate_size, num_groups=config.num_groups
  270. )
  271. if isinstance(config.hidden_act, str):
  272. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  273. else:
  274. self.intermediate_act_fn = config.hidden_act
  275. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  276. hidden_states = self.dense(hidden_states)
  277. hidden_states = self.intermediate_act_fn(hidden_states)
  278. return hidden_states
  279. class ConvBertOutput(nn.Module):
  280. def __init__(self, config):
  281. super().__init__()
  282. if config.num_groups == 1:
  283. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  284. else:
  285. self.dense = GroupedLinearLayer(
  286. input_size=config.intermediate_size, output_size=config.hidden_size, num_groups=config.num_groups
  287. )
  288. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  289. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  290. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  291. hidden_states = self.dense(hidden_states)
  292. hidden_states = self.dropout(hidden_states)
  293. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  294. return hidden_states
  295. class ConvBertLayer(GradientCheckpointingLayer):
  296. def __init__(self, config):
  297. super().__init__()
  298. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  299. self.seq_len_dim = 1
  300. self.attention = ConvBertAttention(config)
  301. self.is_decoder = config.is_decoder
  302. self.add_cross_attention = config.add_cross_attention
  303. if self.add_cross_attention:
  304. if not self.is_decoder:
  305. raise TypeError(f"{self} should be used as a decoder model if cross attention is added")
  306. self.crossattention = ConvBertAttention(config)
  307. self.intermediate = ConvBertIntermediate(config)
  308. self.output = ConvBertOutput(config)
  309. def forward(
  310. self,
  311. hidden_states: torch.Tensor,
  312. attention_mask: torch.FloatTensor | None = None,
  313. encoder_hidden_states: torch.Tensor | None = None,
  314. encoder_attention_mask: torch.Tensor | None = None,
  315. **kwargs: Unpack[TransformersKwargs],
  316. ) -> torch.Tensor:
  317. attention_output = self.attention(
  318. hidden_states,
  319. attention_mask,
  320. **kwargs,
  321. )
  322. if self.is_decoder and encoder_hidden_states is not None:
  323. if not hasattr(self, "crossattention"):
  324. raise AttributeError(
  325. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  326. " by setting `config.add_cross_attention=True`"
  327. )
  328. attention_output = self.crossattention(
  329. attention_output,
  330. encoder_attention_mask,
  331. encoder_hidden_states=encoder_hidden_states,
  332. **kwargs,
  333. )
  334. layer_output = apply_chunking_to_forward(
  335. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  336. )
  337. return layer_output
  338. def feed_forward_chunk(self, attention_output):
  339. intermediate_output = self.intermediate(attention_output)
  340. layer_output = self.output(intermediate_output, attention_output)
  341. return layer_output
  342. @auto_docstring
  343. class ConvBertPreTrainedModel(PreTrainedModel):
  344. config: ConvBertConfig
  345. base_model_prefix = "convbert"
  346. supports_gradient_checkpointing = True
  347. _can_record_outputs = {
  348. "hidden_states": ConvBertLayer,
  349. "attentions": ConvBertSelfAttention,
  350. }
  351. @torch.no_grad()
  352. def _init_weights(self, module):
  353. """Initialize the weights"""
  354. super()._init_weights(module)
  355. if isinstance(module, SeparableConv1D):
  356. init.zeros_(module.bias)
  357. elif isinstance(module, GroupedLinearLayer):
  358. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  359. init.zeros_(module.bias)
  360. elif isinstance(module, ConvBertEmbeddings):
  361. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  362. init.zeros_(module.token_type_ids)
  363. class ConvBertEncoder(nn.Module):
  364. def __init__(self, config):
  365. super().__init__()
  366. self.config = config
  367. self.layer = nn.ModuleList([ConvBertLayer(config) for _ in range(config.num_hidden_layers)])
  368. self.gradient_checkpointing = False
  369. def forward(
  370. self,
  371. hidden_states: torch.Tensor,
  372. attention_mask: torch.FloatTensor | None = None,
  373. encoder_hidden_states: torch.Tensor | None = None,
  374. encoder_attention_mask: torch.Tensor | None = None,
  375. **kwargs,
  376. ) -> BaseModelOutputWithCrossAttentions:
  377. for layer_module in self.layer:
  378. hidden_states = layer_module(
  379. hidden_states,
  380. attention_mask,
  381. encoder_hidden_states=encoder_hidden_states,
  382. encoder_attention_mask=encoder_attention_mask,
  383. **kwargs,
  384. )
  385. return BaseModelOutputWithCrossAttentions(
  386. last_hidden_state=hidden_states,
  387. )
  388. class ConvBertPredictionHeadTransform(nn.Module):
  389. def __init__(self, config):
  390. super().__init__()
  391. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  392. if isinstance(config.hidden_act, str):
  393. self.transform_act_fn = ACT2FN[config.hidden_act]
  394. else:
  395. self.transform_act_fn = config.hidden_act
  396. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  397. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  398. hidden_states = self.dense(hidden_states)
  399. hidden_states = self.transform_act_fn(hidden_states)
  400. hidden_states = self.LayerNorm(hidden_states)
  401. return hidden_states
  402. # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->ConvBert
  403. class ConvBertSequenceSummary(nn.Module):
  404. r"""
  405. Compute a single vector summary of a sequence hidden states.
  406. Args:
  407. config ([`ConvBertConfig`]):
  408. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  409. config class of your model for the default values it uses):
  410. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  411. - `"last"` -- Take the last token hidden state (like XLNet)
  412. - `"first"` -- Take the first token hidden state (like Bert)
  413. - `"mean"` -- Take the mean of all tokens hidden states
  414. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  415. - `"attn"` -- Not implemented now, use multi-head attention
  416. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  417. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  418. (otherwise to `config.hidden_size`).
  419. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  420. another string or `None` will add no activation.
  421. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
  422. - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
  423. """
  424. def __init__(self, config: ConvBertConfig):
  425. super().__init__()
  426. self.summary_type = getattr(config, "summary_type", "last")
  427. if self.summary_type == "attn":
  428. # We should use a standard multi-head attention module with absolute positional embedding for that.
  429. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  430. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  431. raise NotImplementedError
  432. self.summary = nn.Identity()
  433. if hasattr(config, "summary_use_proj") and config.summary_use_proj:
  434. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  435. num_classes = config.num_labels
  436. else:
  437. num_classes = config.hidden_size
  438. self.summary = nn.Linear(config.hidden_size, num_classes)
  439. activation_string = getattr(config, "summary_activation", None)
  440. self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
  441. self.first_dropout = nn.Identity()
  442. if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
  443. self.first_dropout = nn.Dropout(config.summary_first_dropout)
  444. self.last_dropout = nn.Identity()
  445. if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
  446. self.last_dropout = nn.Dropout(config.summary_last_dropout)
  447. def forward(
  448. self, hidden_states: torch.FloatTensor, cls_index: torch.LongTensor | None = None
  449. ) -> torch.FloatTensor:
  450. """
  451. Compute a single vector summary of a sequence hidden states.
  452. Args:
  453. hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
  454. The hidden states of the last layer.
  455. cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
  456. Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
  457. Returns:
  458. `torch.FloatTensor`: The summary of the sequence hidden states.
  459. """
  460. if self.summary_type == "last":
  461. output = hidden_states[:, -1]
  462. elif self.summary_type == "first":
  463. output = hidden_states[:, 0]
  464. elif self.summary_type == "mean":
  465. output = hidden_states.mean(dim=1)
  466. elif self.summary_type == "cls_index":
  467. if cls_index is None:
  468. cls_index = torch.full_like(
  469. hidden_states[..., :1, :],
  470. hidden_states.shape[-2] - 1,
  471. dtype=torch.long,
  472. )
  473. else:
  474. cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
  475. cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
  476. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  477. output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
  478. elif self.summary_type == "attn":
  479. raise NotImplementedError
  480. output = self.first_dropout(output)
  481. output = self.summary(output)
  482. output = self.activation(output)
  483. output = self.last_dropout(output)
  484. return output
  485. @auto_docstring
  486. class ConvBertModel(ConvBertPreTrainedModel):
  487. def __init__(self, config):
  488. super().__init__(config)
  489. self.embeddings = ConvBertEmbeddings(config)
  490. if config.embedding_size != config.hidden_size:
  491. self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)
  492. self.encoder = ConvBertEncoder(config)
  493. self.config = config
  494. # Initialize weights and apply final processing
  495. self.post_init()
  496. def get_input_embeddings(self):
  497. return self.embeddings.word_embeddings
  498. def set_input_embeddings(self, value):
  499. self.embeddings.word_embeddings = value
  500. @merge_with_config_defaults
  501. @capture_outputs
  502. @auto_docstring
  503. def forward(
  504. self,
  505. input_ids: torch.LongTensor | None = None,
  506. attention_mask: torch.FloatTensor | None = None,
  507. token_type_ids: torch.LongTensor | None = None,
  508. position_ids: torch.LongTensor | None = None,
  509. inputs_embeds: torch.FloatTensor | None = None,
  510. **kwargs: Unpack[TransformersKwargs],
  511. ) -> BaseModelOutputWithCrossAttentions:
  512. if input_ids is not None and inputs_embeds is not None:
  513. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  514. elif input_ids is not None:
  515. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  516. input_shape = input_ids.size()
  517. elif inputs_embeds is not None:
  518. input_shape = inputs_embeds.size()[:-1]
  519. else:
  520. raise ValueError("You have to specify either input_ids or inputs_embeds")
  521. batch_size, seq_length = input_shape
  522. device = input_ids.device if input_ids is not None else inputs_embeds.device
  523. if attention_mask is None:
  524. attention_mask = torch.ones(input_shape, device=device)
  525. if token_type_ids is None:
  526. if hasattr(self.embeddings, "token_type_ids"):
  527. buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
  528. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
  529. token_type_ids = buffered_token_type_ids_expanded
  530. else:
  531. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  532. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
  533. hidden_states = self.embeddings(
  534. input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  535. )
  536. if hasattr(self, "embeddings_project"):
  537. hidden_states = self.embeddings_project(hidden_states)
  538. encoder_outputs: BaseModelOutputWithCrossAttentions = self.encoder(
  539. hidden_states,
  540. attention_mask=extended_attention_mask,
  541. **kwargs,
  542. )
  543. return encoder_outputs
  544. class ConvBertGeneratorPredictions(nn.Module):
  545. """Prediction module for the generator, made up of two dense layers."""
  546. def __init__(self, config):
  547. super().__init__()
  548. self.activation = get_activation("gelu")
  549. self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
  550. self.dense = nn.Linear(config.hidden_size, config.embedding_size)
  551. def forward(self, generator_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  552. hidden_states = self.dense(generator_hidden_states)
  553. hidden_states = self.activation(hidden_states)
  554. hidden_states = self.LayerNorm(hidden_states)
  555. return hidden_states
  @auto_docstring
  class ConvBertForMaskedLM(ConvBertPreTrainedModel):
      # Declares that the LM output projection weight is tied to the input word-embedding weight.
      _tied_weights_keys = {"generator_lm_head.weight": "convbert.embeddings.word_embeddings.weight"}

      def __init__(self, config):
          super().__init__(config)

          self.convbert = ConvBertModel(config)
          # hidden_size -> embedding_size transform, followed by the vocabulary projection.
          self.generator_predictions = ConvBertGeneratorPredictions(config)
          self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)

          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          return self.generator_lm_head

      def set_output_embeddings(self, word_embeddings):
          self.generator_lm_head = word_embeddings

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          token_type_ids: torch.LongTensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | MaskedLMOutput:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
              config.vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
          """
          generator_hidden_states: BaseModelOutputWithCrossAttentions = self.convbert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              **kwargs,
          )
          generator_sequence_output = generator_hidden_states[0]

          # Project encoder states to vocabulary logits.
          prediction_scores = self.generator_predictions(generator_sequence_output)
          prediction_scores = self.generator_lm_head(prediction_scores)

          loss = None
          if labels is not None:
              # Targets equal to -100 are skipped: that is CrossEntropyLoss's default `ignore_index`.
              loss_fct = nn.CrossEntropyLoss()
              # `reshape` (unlike `view`) also accepts non-contiguous label tensors, e.g. slices.
              loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.reshape(-1))

          return MaskedLMOutput(
              loss=loss,
              logits=prediction_scores,
              hidden_states=generator_hidden_states.hidden_states,
              attentions=generator_hidden_states.attentions,
          )
  class ConvBertClassificationHead(nn.Module):
      """Sequence-level classifier that reads only the hidden state at position 0.

      Pipeline: dropout -> dense (hidden_size -> hidden_size) -> ``config.hidden_act`` -> dropout
      -> out_proj (hidden_size -> num_labels).
      """

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.hidden_size)
          # Fall back to the generic hidden dropout when no classifier-specific rate is configured.
          if config.classifier_dropout is None:
              dropout_prob = config.hidden_dropout_prob
          else:
              dropout_prob = config.classifier_dropout
          self.dropout = nn.Dropout(dropout_prob)
          self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
          self.config = config

      def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
          # First-token pooling: the <s> token, equivalent to [CLS].
          first_token = self.dropout(hidden_states[:, 0, :])
          projected = self.dense(first_token)
          act_fn = ACT2FN[self.config.hidden_act]
          projected = self.dropout(act_fn(projected))
          return self.out_proj(projected)
  @auto_docstring(
      custom_intro="""
      ConvBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
      pooled output) e.g. for GLUE tasks.
      """
  )
  class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels
          self.config = config
          self.convbert = ConvBertModel(config)
          self.classifier = ConvBertClassificationHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          token_type_ids: torch.LongTensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | SequenceClassifierOutput:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
          """
          encoder_out: BaseModelOutputWithCrossAttentions = self.convbert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              **kwargs,
          )
          logits = self.classifier(encoder_out[0])

          loss = None if labels is None else self._compute_sequence_classification_loss(logits, labels)

          return SequenceClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=encoder_out.hidden_states,
              attentions=encoder_out.attentions,
          )

      def _compute_sequence_classification_loss(self, logits, labels):
          """Return the loss for ``config.problem_type``, inferring and caching it on the config if unset.

          Inference rule: one label -> regression; several labels with integer (long/int) targets ->
          single-label classification; anything else -> multi-label classification. Returns None when
          ``problem_type`` holds some other value.
          """
          if self.config.problem_type is None:
              if self.num_labels == 1:
                  inferred = "regression"
              elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                  inferred = "single_label_classification"
              else:
                  inferred = "multi_label_classification"
              # Side effect preserved on purpose: the inferred type sticks to the shared config.
              self.config.problem_type = inferred

          kind = self.config.problem_type
          if kind == "regression":
              mse = MSELoss()
              if self.num_labels == 1:
                  return mse(logits.squeeze(), labels.squeeze())
              return mse(logits, labels)
          if kind == "single_label_classification":
              return CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
          if kind == "multi_label_classification":
              return BCEWithLogitsLoss()(logits, labels)
          return None
  @auto_docstring
  class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          self.convbert = ConvBertModel(config)
          # Presumably reduces each sequence to a single vector (ConvBertSequenceSummary is defined
          # elsewhere in this file) — confirm against its definition.
          self.sequence_summary = ConvBertSequenceSummary(config)
          # One scalar score per (example, choice) pair.
          self.classifier = nn.Linear(config.hidden_size, 1)

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          token_type_ids: torch.LongTensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | MultipleChoiceModelOutput:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
              Indices of input sequence tokens in the vocabulary.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are input IDs?](../glossary#input-ids)
          token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
              Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
              1]`:

              - 0 corresponds to a *sentence A* token,
              - 1 corresponds to a *sentence B* token.

              [What are token type IDs?](../glossary#token-type-ids)
          position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
              Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
              config.max_position_embeddings - 1]`.

              [What are position IDs?](../glossary#position-ids)
          inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
              Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
              is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
              model's internal embedding lookup matrix.
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
              num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
              `input_ids` above)
          """
          # Without this guard the next line fails with an opaque AttributeError on `None.shape`.
          if input_ids is None and inputs_embeds is None:
              raise ValueError("You have to specify either input_ids or inputs_embeds")
          num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

          # Fold the choice dimension into the batch dimension:
          # (batch_size, num_choices, sequence_length[, hidden]) -> (batch_size * num_choices, sequence_length[, hidden]).
          # `reshape` (unlike `view`) also works for non-contiguous inputs.
          input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None
          attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None
          token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
          position_ids = position_ids.reshape(-1, position_ids.size(-1)) if position_ids is not None else None
          inputs_embeds = (
              inputs_embeds.reshape(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
              if inputs_embeds is not None
              else None
          )

          outputs: BaseModelOutputWithCrossAttentions = self.convbert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              **kwargs,
          )

          sequence_output = outputs[0]
          pooled_output = self.sequence_summary(sequence_output)
          logits = self.classifier(pooled_output)
          # Regroup the per-choice scores so each row holds the num_choices scores of one example.
          reshaped_logits = logits.view(-1, num_choices)

          loss = None
          if labels is not None:
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(reshaped_logits, labels)

          return MultipleChoiceModelOutput(
              loss=loss,
              logits=reshaped_logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  @auto_docstring
  class ConvBertForTokenClassification(ConvBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels

          self.convbert = ConvBertModel(config)
          # Classifier-specific dropout if configured, otherwise the generic hidden dropout.
          if config.classifier_dropout is None:
              drop_rate = config.hidden_dropout_prob
          else:
              drop_rate = config.classifier_dropout
          self.dropout = nn.Dropout(drop_rate)
          # Per-token linear classifier: hidden_size -> num_labels.
          self.classifier = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          token_type_ids: torch.LongTensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | TokenClassifierOutput:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
          """
          encoded: BaseModelOutputWithCrossAttentions = self.convbert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              **kwargs,
          )

          token_states = self.dropout(encoded[0])
          logits = self.classifier(token_states)

          if labels is None:
              loss = None
          else:
              # Flatten every token into one row so CrossEntropyLoss scores all positions at once.
              flat_logits = logits.view(-1, self.num_labels)
              flat_labels = labels.view(-1)
              loss = CrossEntropyLoss()(flat_logits, flat_labels)

          return TokenClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=encoded.hidden_states,
              attentions=encoded.attentions,
          )
  @auto_docstring
  class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          self.num_labels = config.num_labels
          self.convbert = ConvBertModel(config)
          # Span head. `forward` unpacks its output into exactly two logit columns (start, end), so
          # config.num_labels is expected to be 2; any other value makes that unpack fail.
          self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          token_type_ids: torch.LongTensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          start_positions: torch.LongTensor | None = None,
          end_positions: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | QuestionAnsweringModelOutput:
          outputs: BaseModelOutputWithCrossAttentions = self.convbert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              **kwargs,
          )

          sequence_output = outputs[0]

          # Split the last dimension into separate start/end logit tensors and drop that size-1 axis.
          logits = self.qa_outputs(sequence_output)
          start_logits, end_logits = logits.split(1, dim=-1)
          start_logits = start_logits.squeeze(-1).contiguous()
          end_logits = end_logits.squeeze(-1).contiguous()

          total_loss = None
          if start_positions is not None and end_positions is not None:
              # Positions may arrive with an extra trailing dimension (e.g. after multi-GPU splitting); drop it.
              if len(start_positions.size()) > 1:
                  start_positions = start_positions.squeeze(-1)
              if len(end_positions.size()) > 1:
                  end_positions = end_positions.squeeze(-1)
              # Clamp positions into [0, ignored_index]. Anything past the end of the logits becomes
              # `ignored_index`, which the loss is configured to skip. Negative positions clamp to 0
              # and are NOT ignored.
              ignored_index = start_logits.size(1)
              start_positions = start_positions.clamp(0, ignored_index)
              end_positions = end_positions.clamp(0, ignored_index)

              loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
              start_loss = loss_fct(start_logits, start_positions)
              end_loss = loss_fct(end_logits, end_positions)
              total_loss = (start_loss + end_loss) / 2

          return QuestionAnsweringModelOutput(
              loss=total_loss,
              start_logits=start_logits,
              end_logits=end_logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  # Public names exported by this module (kept in alphabetical order).
  __all__ = [
      # Task-specific heads
      "ConvBertForMaskedLM", "ConvBertForMultipleChoice", "ConvBertForQuestionAnswering",
      "ConvBertForSequenceClassification", "ConvBertForTokenClassification",
      # Building blocks and base classes
      "ConvBertLayer", "ConvBertModel", "ConvBertPreTrainedModel",
  ]