modeling_mra.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332
  1. # Copyright 2023 University of Wisconsin-Madison and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch MRA model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...modeling_layers import GradientCheckpointingLayer
  22. from ...modeling_outputs import (
  23. BaseModelOutputWithCrossAttentions,
  24. MaskedLMOutput,
  25. MultipleChoiceModelOutput,
  26. QuestionAnsweringModelOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import PreTrainedModel
  31. from ...pytorch_utils import apply_chunking_to_forward
  32. from ...utils import (
  33. auto_docstring,
  34. is_cuda_platform,
  35. is_kernels_available,
  36. is_ninja_available,
  37. is_torch_cuda_available,
  38. logging,
  39. )
  40. from .configuration_mra import MraConfig
# Module-level logger obtained from the library's `logging` utility.
logger = logging.get_logger(__name__)

# Handle to the MRA CUDA kernel module. It stays `None` until `load_cuda_kernels()` assigns it; `mra2_attention`
# checks it for `None` before doing any work, and `MraSelfAttention.__init__` uses it to decide whether a load
# attempt is still needed.
mra_cuda_kernel = None
  43. def load_cuda_kernels():
  44. global mra_cuda_kernel
  45. if not is_kernels_available():
  46. raise ImportError("kernels is not installed, please install it with `pip install kernels`")
  47. from ...integrations.hub_kernels import get_kernel
  48. mra_cuda_kernel = get_kernel("kernels-community/mra")
  49. def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):
  50. """
  51. Computes maximum values for softmax stability.
  52. """
  53. if len(sparse_qk_prod.size()) != 4:
  54. raise ValueError("sparse_qk_prod must be a 4-dimensional tensor.")
  55. if len(indices.size()) != 2:
  56. raise ValueError("indices must be a 2-dimensional tensor.")
  57. if sparse_qk_prod.size(2) != 32:
  58. raise ValueError("The size of the second dimension of sparse_qk_prod must be 32.")
  59. if sparse_qk_prod.size(3) != 32:
  60. raise ValueError("The size of the third dimension of sparse_qk_prod must be 32.")
  61. index_vals = sparse_qk_prod.max(dim=-2).values.transpose(-1, -2)
  62. index_vals = index_vals.contiguous()
  63. indices = indices.int()
  64. indices = indices.contiguous()
  65. max_vals, max_vals_scatter = mra_cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block)
  66. max_vals_scatter = max_vals_scatter.transpose(-1, -2)[:, :, None, :]
  67. return max_vals, max_vals_scatter
  68. def sparse_mask(mask, indices, block_size=32):
  69. """
  70. Converts attention mask to a sparse mask for high resolution logits.
  71. """
  72. if len(mask.size()) != 2:
  73. raise ValueError("mask must be a 2-dimensional tensor.")
  74. if len(indices.size()) != 2:
  75. raise ValueError("indices must be a 2-dimensional tensor.")
  76. if mask.shape[0] != indices.shape[0]:
  77. raise ValueError("mask and indices must have the same size in the zero-th dimension.")
  78. batch_size, seq_len = mask.shape
  79. num_block = seq_len // block_size
  80. batch_idx = torch.arange(indices.size(0), dtype=torch.long, device=indices.device)
  81. mask = mask.reshape(batch_size, num_block, block_size)
  82. mask = mask[batch_idx[:, None], (indices % num_block).long(), :]
  83. return mask
  84. def mm_to_sparse(dense_query, dense_key, indices, block_size=32):
  85. """
  86. Performs Sampled Dense Matrix Multiplication.
  87. """
  88. batch_size, query_size, dim = dense_query.size()
  89. _, key_size, dim = dense_key.size()
  90. if query_size % block_size != 0:
  91. raise ValueError("query_size (size of first dimension of dense_query) must be divisible by block_size.")
  92. if key_size % block_size != 0:
  93. raise ValueError("key_size (size of first dimension of dense_key) must be divisible by block_size.")
  94. dense_query = dense_query.reshape(batch_size, query_size // block_size, block_size, dim).transpose(-1, -2)
  95. dense_key = dense_key.reshape(batch_size, key_size // block_size, block_size, dim).transpose(-1, -2)
  96. if len(dense_query.size()) != 4:
  97. raise ValueError("dense_query must be a 4-dimensional tensor.")
  98. if len(dense_key.size()) != 4:
  99. raise ValueError("dense_key must be a 4-dimensional tensor.")
  100. if len(indices.size()) != 2:
  101. raise ValueError("indices must be a 2-dimensional tensor.")
  102. if dense_query.size(3) != 32:
  103. raise ValueError("The third dimension of dense_query must be 32.")
  104. if dense_key.size(3) != 32:
  105. raise ValueError("The third dimension of dense_key must be 32.")
  106. dense_query = dense_query.contiguous()
  107. dense_key = dense_key.contiguous()
  108. indices = indices.int()
  109. indices = indices.contiguous()
  110. return mra_cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int())
  111. def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_size=32):
  112. """
  113. Performs matrix multiplication of a sparse matrix with a dense matrix.
  114. """
  115. batch_size, key_size, dim = dense_key.size()
  116. if key_size % block_size != 0:
  117. raise ValueError("key_size (size of first dimension of dense_key) must be divisible by block_size.")
  118. if sparse_query.size(2) != block_size:
  119. raise ValueError("The size of the second dimension of sparse_query must be equal to the block_size.")
  120. if sparse_query.size(3) != block_size:
  121. raise ValueError("The size of the third dimension of sparse_query must be equal to the block_size.")
  122. dense_key = dense_key.reshape(batch_size, key_size // block_size, block_size, dim).transpose(-1, -2)
  123. if len(sparse_query.size()) != 4:
  124. raise ValueError("sparse_query must be a 4-dimensional tensor.")
  125. if len(dense_key.size()) != 4:
  126. raise ValueError("dense_key must be a 4-dimensional tensor.")
  127. if len(indices.size()) != 2:
  128. raise ValueError("indices must be a 2-dimensional tensor.")
  129. if dense_key.size(3) != 32:
  130. raise ValueError("The size of the third dimension of dense_key must be 32.")
  131. sparse_query = sparse_query.contiguous()
  132. indices = indices.int()
  133. indices = indices.contiguous()
  134. dense_key = dense_key.contiguous()
  135. dense_qk_prod = mra_cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
  136. dense_qk_prod = dense_qk_prod.transpose(-1, -2).reshape(batch_size, query_num_block * block_size, dim)
  137. return dense_qk_prod
  138. def transpose_indices(indices, dim_1_block, dim_2_block):
  139. return ((indices % dim_2_block) * dim_1_block + torch.div(indices, dim_2_block, rounding_mode="floor")).long()
  140. class MraSampledDenseMatMul(torch.autograd.Function):
  141. @staticmethod
  142. def forward(ctx, dense_query, dense_key, indices, block_size):
  143. sparse_qk_prod = mm_to_sparse(dense_query, dense_key, indices, block_size)
  144. ctx.save_for_backward(dense_query, dense_key, indices)
  145. ctx.block_size = block_size
  146. return sparse_qk_prod
  147. @staticmethod
  148. def backward(ctx, grad):
  149. dense_query, dense_key, indices = ctx.saved_tensors
  150. block_size = ctx.block_size
  151. query_num_block = dense_query.size(1) // block_size
  152. key_num_block = dense_key.size(1) // block_size
  153. indices_T = transpose_indices(indices, query_num_block, key_num_block)
  154. grad_key = sparse_dense_mm(grad.transpose(-1, -2), indices_T, dense_query, key_num_block)
  155. grad_query = sparse_dense_mm(grad, indices, dense_key, query_num_block)
  156. return grad_query, grad_key, None, None
  157. @staticmethod
  158. def operator_call(dense_query, dense_key, indices, block_size=32):
  159. return MraSampledDenseMatMul.apply(dense_query, dense_key, indices, block_size)
  160. class MraSparseDenseMatMul(torch.autograd.Function):
  161. @staticmethod
  162. def forward(ctx, sparse_query, indices, dense_key, query_num_block):
  163. sparse_qk_prod = sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
  164. ctx.save_for_backward(sparse_query, indices, dense_key)
  165. ctx.query_num_block = query_num_block
  166. return sparse_qk_prod
  167. @staticmethod
  168. def backward(ctx, grad):
  169. sparse_query, indices, dense_key = ctx.saved_tensors
  170. query_num_block = ctx.query_num_block
  171. key_num_block = dense_key.size(1) // sparse_query.size(-1)
  172. indices_T = transpose_indices(indices, query_num_block, key_num_block)
  173. grad_key = sparse_dense_mm(sparse_query.transpose(-1, -2), indices_T, grad, key_num_block)
  174. grad_query = mm_to_sparse(grad, dense_key, indices)
  175. return grad_query, None, grad_key, None
  176. @staticmethod
  177. def operator_call(sparse_query, indices, dense_key, query_num_block):
  178. return MraSparseDenseMatMul.apply(sparse_query, indices, dense_key, query_num_block)
  179. class MraReduceSum:
  180. @staticmethod
  181. def operator_call(sparse_query, indices, query_num_block, key_num_block):
  182. batch_size, num_block, block_size, _ = sparse_query.size()
  183. if len(sparse_query.size()) != 4:
  184. raise ValueError("sparse_query must be a 4-dimensional tensor.")
  185. if len(indices.size()) != 2:
  186. raise ValueError("indices must be a 2-dimensional tensor.")
  187. _, _, block_size, _ = sparse_query.size()
  188. batch_size, num_block = indices.size()
  189. sparse_query = sparse_query.sum(dim=2).reshape(batch_size * num_block, block_size)
  190. batch_idx = torch.arange(indices.size(0), dtype=torch.long, device=indices.device)
  191. global_idxes = (
  192. torch.div(indices, key_num_block, rounding_mode="floor").long() + batch_idx[:, None] * query_num_block
  193. ).reshape(batch_size * num_block)
  194. temp = torch.zeros(
  195. (batch_size * query_num_block, block_size), dtype=sparse_query.dtype, device=sparse_query.device
  196. )
  197. output = temp.index_add(0, global_idxes, sparse_query).reshape(batch_size, query_num_block, block_size)
  198. output = output.reshape(batch_size, query_num_block * block_size)
  199. return output
  200. def get_low_resolution_logit(query, key, block_size, mask=None, value=None):
  201. """
  202. Compute low resolution approximation.
  203. """
  204. batch_size, seq_len, head_dim = query.size()
  205. num_block_per_row = seq_len // block_size
  206. value_hat = None
  207. if mask is not None:
  208. token_count = mask.reshape(batch_size, num_block_per_row, block_size).sum(dim=-1)
  209. query_hat = query.reshape(batch_size, num_block_per_row, block_size, head_dim).sum(dim=-2) / (
  210. token_count[:, :, None] + 1e-6
  211. )
  212. key_hat = key.reshape(batch_size, num_block_per_row, block_size, head_dim).sum(dim=-2) / (
  213. token_count[:, :, None] + 1e-6
  214. )
  215. if value is not None:
  216. value_hat = value.reshape(batch_size, num_block_per_row, block_size, head_dim).sum(dim=-2) / (
  217. token_count[:, :, None] + 1e-6
  218. )
  219. else:
  220. token_count = block_size * torch.ones(batch_size, num_block_per_row, dtype=torch.float, device=query.device)
  221. query_hat = query.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim=-2)
  222. key_hat = key.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim=-2)
  223. if value is not None:
  224. value_hat = value.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim=-2)
  225. low_resolution_logit = torch.matmul(query_hat, key_hat.transpose(-1, -2)) / math.sqrt(head_dim)
  226. low_resolution_logit_row_max = low_resolution_logit.max(dim=-1, keepdims=True).values
  227. if mask is not None:
  228. low_resolution_logit = (
  229. low_resolution_logit - 1e4 * ((token_count[:, None, :] * token_count[:, :, None]) < 0.5).float()
  230. )
  231. return low_resolution_logit, token_count, low_resolution_logit_row_max, value_hat
  232. def get_block_idxes(
  233. low_resolution_logit, num_blocks, approx_mode, initial_prior_first_n_blocks, initial_prior_diagonal_n_blocks
  234. ):
  235. """
  236. Compute the indices of the subset of components to be used in the approximation.
  237. """
  238. batch_size, total_blocks_per_row, _ = low_resolution_logit.shape
  239. if initial_prior_diagonal_n_blocks > 0:
  240. offset = initial_prior_diagonal_n_blocks // 2
  241. temp_mask = torch.ones(total_blocks_per_row, total_blocks_per_row, device=low_resolution_logit.device)
  242. diagonal_mask = torch.tril(torch.triu(temp_mask, diagonal=-offset), diagonal=offset)
  243. low_resolution_logit = low_resolution_logit + diagonal_mask[None, :, :] * 5e3
  244. if initial_prior_first_n_blocks > 0:
  245. low_resolution_logit[:, :initial_prior_first_n_blocks, :] = (
  246. low_resolution_logit[:, :initial_prior_first_n_blocks, :] + 5e3
  247. )
  248. low_resolution_logit[:, :, :initial_prior_first_n_blocks] = (
  249. low_resolution_logit[:, :, :initial_prior_first_n_blocks] + 5e3
  250. )
  251. top_k_vals = torch.topk(
  252. low_resolution_logit.reshape(batch_size, -1), num_blocks, dim=-1, largest=True, sorted=False
  253. )
  254. indices = top_k_vals.indices
  255. if approx_mode == "full":
  256. threshold = top_k_vals.values.min(dim=-1).values
  257. high_resolution_mask = (low_resolution_logit >= threshold[:, None, None]).float()
  258. elif approx_mode == "sparse":
  259. high_resolution_mask = None
  260. else:
  261. raise ValueError(f"{approx_mode} is not a valid approx_model value.")
  262. return indices, high_resolution_mask
  263. def mra2_attention(
  264. query,
  265. key,
  266. value,
  267. mask,
  268. num_blocks,
  269. approx_mode,
  270. block_size=32,
  271. initial_prior_first_n_blocks=0,
  272. initial_prior_diagonal_n_blocks=0,
  273. ):
  274. """
  275. Use Mra to approximate self-attention.
  276. """
  277. if mra_cuda_kernel is None:
  278. return torch.zeros_like(query).requires_grad_()
  279. batch_size, num_head, seq_len, head_dim = query.size()
  280. meta_batch = batch_size * num_head
  281. if seq_len % block_size != 0:
  282. raise ValueError("sequence length must be divisible by the block_size.")
  283. num_block_per_row = seq_len // block_size
  284. query = query.reshape(meta_batch, seq_len, head_dim)
  285. key = key.reshape(meta_batch, seq_len, head_dim)
  286. value = value.reshape(meta_batch, seq_len, head_dim)
  287. if mask is not None:
  288. query = query * mask[:, :, None]
  289. key = key * mask[:, :, None]
  290. value = value * mask[:, :, None]
  291. if approx_mode == "full":
  292. low_resolution_logit, token_count, low_resolution_logit_row_max, value_hat = get_low_resolution_logit(
  293. query, key, block_size, mask, value
  294. )
  295. elif approx_mode == "sparse":
  296. with torch.no_grad():
  297. low_resolution_logit, token_count, low_resolution_logit_row_max, _ = get_low_resolution_logit(
  298. query, key, block_size, mask
  299. )
  300. else:
  301. raise Exception('approx_mode must be "full" or "sparse"')
  302. with torch.no_grad():
  303. low_resolution_logit_normalized = low_resolution_logit - low_resolution_logit_row_max
  304. indices, high_resolution_mask = get_block_idxes(
  305. low_resolution_logit_normalized,
  306. num_blocks,
  307. approx_mode,
  308. initial_prior_first_n_blocks,
  309. initial_prior_diagonal_n_blocks,
  310. )
  311. high_resolution_logit = MraSampledDenseMatMul.operator_call(
  312. query, key, indices, block_size=block_size
  313. ) / math.sqrt(head_dim)
  314. max_vals, max_vals_scatter = sparse_max(high_resolution_logit, indices, num_block_per_row, num_block_per_row)
  315. high_resolution_logit = high_resolution_logit - max_vals_scatter
  316. if mask is not None:
  317. high_resolution_logit = high_resolution_logit - 1e4 * (1 - sparse_mask(mask, indices)[:, :, :, None])
  318. high_resolution_attn = torch.exp(high_resolution_logit)
  319. high_resolution_attn_out = MraSparseDenseMatMul.operator_call(
  320. high_resolution_attn, indices, value, num_block_per_row
  321. )
  322. high_resolution_normalizer = MraReduceSum.operator_call(
  323. high_resolution_attn, indices, num_block_per_row, num_block_per_row
  324. )
  325. if approx_mode == "full":
  326. low_resolution_attn = (
  327. torch.exp(low_resolution_logit - low_resolution_logit_row_max - 1e4 * high_resolution_mask)
  328. * token_count[:, None, :]
  329. )
  330. low_resolution_attn_out = (
  331. torch.matmul(low_resolution_attn, value_hat)[:, :, None, :]
  332. .repeat(1, 1, block_size, 1)
  333. .reshape(meta_batch, seq_len, head_dim)
  334. )
  335. low_resolution_normalizer = (
  336. low_resolution_attn.sum(dim=-1)[:, :, None].repeat(1, 1, block_size).reshape(meta_batch, seq_len)
  337. )
  338. log_correction = low_resolution_logit_row_max.repeat(1, 1, block_size).reshape(meta_batch, seq_len) - max_vals
  339. if mask is not None:
  340. log_correction = log_correction * mask
  341. low_resolution_corr = torch.exp(log_correction * (log_correction <= 0).float())
  342. low_resolution_attn_out = low_resolution_attn_out * low_resolution_corr[:, :, None]
  343. low_resolution_normalizer = low_resolution_normalizer * low_resolution_corr
  344. high_resolution_corr = torch.exp(-log_correction * (log_correction > 0).float())
  345. high_resolution_attn_out = high_resolution_attn_out * high_resolution_corr[:, :, None]
  346. high_resolution_normalizer = high_resolution_normalizer * high_resolution_corr
  347. context_layer = (high_resolution_attn_out + low_resolution_attn_out) / (
  348. high_resolution_normalizer[:, :, None] + low_resolution_normalizer[:, :, None] + 1e-6
  349. )
  350. elif approx_mode == "sparse":
  351. context_layer = high_resolution_attn_out / (high_resolution_normalizer[:, :, None] + 1e-6)
  352. else:
  353. raise Exception('config.approx_mode must be "full" or "sparse"')
  354. if mask is not None:
  355. context_layer = context_layer * mask[:, :, None]
  356. context_layer = context_layer.reshape(batch_size, num_head, seq_len, head_dim)
  357. return context_layer
  358. class MraEmbeddings(nn.Module):
  359. """Construct the embeddings from word, position and token_type embeddings."""
  360. def __init__(self, config):
  361. super().__init__()
  362. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  363. self.position_embeddings = nn.Embedding(config.max_position_embeddings + 2, config.hidden_size)
  364. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  365. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  366. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  367. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  368. self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
  369. self.register_buffer(
  370. "token_type_ids",
  371. torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
  372. persistent=False,
  373. )
  374. def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
  375. if input_ids is not None:
  376. input_shape = input_ids.size()
  377. else:
  378. input_shape = inputs_embeds.size()[:-1]
  379. seq_length = input_shape[1]
  380. if position_ids is None:
  381. position_ids = self.position_ids[:, :seq_length]
  382. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  383. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  384. # issue #5664
  385. if token_type_ids is None:
  386. if hasattr(self, "token_type_ids"):
  387. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  388. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  389. token_type_ids = buffered_token_type_ids_expanded
  390. else:
  391. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  392. if inputs_embeds is None:
  393. inputs_embeds = self.word_embeddings(input_ids)
  394. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  395. embeddings = inputs_embeds + token_type_embeddings
  396. position_embeddings = self.position_embeddings(position_ids)
  397. embeddings += position_embeddings
  398. embeddings = self.LayerNorm(embeddings)
  399. embeddings = self.dropout(embeddings)
  400. return embeddings
  401. class MraSelfAttention(nn.Module):
  402. def __init__(self, config):
  403. super().__init__()
  404. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  405. raise ValueError(
  406. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  407. f"heads ({config.num_attention_heads})"
  408. )
  409. kernel_loaded = mra_cuda_kernel is not None
  410. if is_torch_cuda_available() and is_cuda_platform() and is_ninja_available() and not kernel_loaded:
  411. try:
  412. load_cuda_kernels()
  413. except Exception as e:
  414. logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
  415. self.num_attention_heads = config.num_attention_heads
  416. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  417. self.all_head_size = self.num_attention_heads * self.attention_head_size
  418. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  419. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  420. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  421. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  422. self.num_block = (config.max_position_embeddings // 32) * config.block_per_row
  423. self.num_block = min(self.num_block, int((config.max_position_embeddings // 32) ** 2))
  424. self.approx_mode = config.approx_mode
  425. self.initial_prior_first_n_blocks = config.initial_prior_first_n_blocks
  426. self.initial_prior_diagonal_n_blocks = config.initial_prior_diagonal_n_blocks
  427. def forward(self, hidden_states, attention_mask=None):
  428. batch_size, seq_len, _ = hidden_states.shape
  429. query_layer = (
  430. self.query(hidden_states)
  431. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  432. .transpose(1, 2)
  433. )
  434. key_layer = (
  435. self.key(hidden_states)
  436. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  437. .transpose(1, 2)
  438. )
  439. value_layer = (
  440. self.value(hidden_states)
  441. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  442. .transpose(1, 2)
  443. )
  444. # revert changes made by get_extended_attention_mask
  445. attention_mask = 1.0 + attention_mask / 10000.0
  446. attention_mask = (
  447. attention_mask.squeeze()
  448. .repeat(1, self.num_attention_heads, 1)
  449. .reshape(batch_size * self.num_attention_heads, seq_len)
  450. .int()
  451. )
  452. # The CUDA kernels are most efficient with inputs whose size is a multiple of a GPU's warp size (32). Inputs
  453. # smaller than this are padded with zeros.
  454. gpu_warp_size = 32
  455. if self.attention_head_size < gpu_warp_size:
  456. pad_size = batch_size, self.num_attention_heads, seq_len, gpu_warp_size - self.attention_head_size
  457. query_layer = torch.cat([query_layer, torch.zeros(pad_size, device=query_layer.device)], dim=-1)
  458. key_layer = torch.cat([key_layer, torch.zeros(pad_size, device=key_layer.device)], dim=-1)
  459. value_layer = torch.cat([value_layer, torch.zeros(pad_size, device=value_layer.device)], dim=-1)
  460. context_layer = mra2_attention(
  461. query_layer.float(),
  462. key_layer.float(),
  463. value_layer.float(),
  464. attention_mask.float(),
  465. self.num_block,
  466. approx_mode=self.approx_mode,
  467. initial_prior_first_n_blocks=self.initial_prior_first_n_blocks,
  468. initial_prior_diagonal_n_blocks=self.initial_prior_diagonal_n_blocks,
  469. )
  470. if self.attention_head_size < gpu_warp_size:
  471. context_layer = context_layer[:, :, :, : self.attention_head_size]
  472. context_layer = context_layer.reshape(batch_size, self.num_attention_heads, seq_len, self.attention_head_size)
  473. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  474. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  475. context_layer = context_layer.view(*new_context_layer_shape)
  476. outputs = (context_layer,)
  477. return outputs
  478. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class MraSelfOutput(nn.Module):
    """Dense projection of the attention output, followed by dropout and a residual layer norm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: Tensor to project; its last dimension must be `config.hidden_size`.
            input_tensor: Residual that is added to the projected tensor before the layer norm.

        Returns:
            `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
  490. class MraAttention(nn.Module):
  491. def __init__(self, config):
  492. super().__init__()
  493. self.self = MraSelfAttention(config)
  494. self.output = MraSelfOutput(config)
  495. def forward(self, hidden_states, attention_mask=None):
  496. self_outputs = self.self(hidden_states, attention_mask)
  497. attention_output = self.output(self_outputs[0], hidden_states)
  498. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  499. return outputs
  500. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
class MraIntermediate(nn.Module):
    """Feed-forward expansion: a dense layer to `config.intermediate_size` followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `config.hidden_act` is either the name of an activation registered in `ACT2FN` or used directly as the
        # activation.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the dense layer and then the activation function to `hidden_states`."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
  513. # Copied from transformers.models.bert.modeling_bert.BertOutput
class MraOutput(nn.Module):
    """Feed-forward contraction back to `config.hidden_size`, followed by dropout and a residual layer norm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: Tensor to project; its last dimension must be `config.intermediate_size`.
            input_tensor: Residual that is added to the projected tensor before the layer norm.

        Returns:
            `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
  525. class MraLayer(GradientCheckpointingLayer):
  526. def __init__(self, config):
  527. super().__init__()
  528. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  529. self.seq_len_dim = 1
  530. self.attention = MraAttention(config)
  531. self.add_cross_attention = config.add_cross_attention
  532. self.intermediate = MraIntermediate(config)
  533. self.output = MraOutput(config)
  534. def forward(self, hidden_states, attention_mask=None):
  535. self_attention_outputs = self.attention(hidden_states, attention_mask)
  536. attention_output = self_attention_outputs[0]
  537. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  538. layer_output = apply_chunking_to_forward(
  539. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  540. )
  541. outputs = (layer_output,) + outputs
  542. return outputs
  543. def feed_forward_chunk(self, attention_output):
  544. intermediate_output = self.intermediate(attention_output)
  545. layer_output = self.output(intermediate_output, attention_output)
  546. return layer_output
  547. class MraEncoder(nn.Module):
  548. def __init__(self, config):
  549. super().__init__()
  550. self.config = config
  551. self.layer = nn.ModuleList([MraLayer(config) for _ in range(config.num_hidden_layers)])
  552. self.gradient_checkpointing = False
  553. def forward(
  554. self,
  555. hidden_states,
  556. attention_mask=None,
  557. output_hidden_states=False,
  558. return_dict=True,
  559. ):
  560. all_hidden_states = () if output_hidden_states else None
  561. for i, layer_module in enumerate(self.layer):
  562. if output_hidden_states:
  563. all_hidden_states = all_hidden_states + (hidden_states,)
  564. layer_outputs = layer_module(hidden_states, attention_mask)
  565. hidden_states = layer_outputs[0]
  566. if output_hidden_states:
  567. all_hidden_states = all_hidden_states + (hidden_states,)
  568. if not return_dict:
  569. return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
  570. return BaseModelOutputWithCrossAttentions(
  571. last_hidden_state=hidden_states,
  572. hidden_states=all_hidden_states,
  573. )
  574. # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform
  575. class MraPredictionHeadTransform(nn.Module):
  576. def __init__(self, config):
  577. super().__init__()
  578. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  579. if isinstance(config.hidden_act, str):
  580. self.transform_act_fn = ACT2FN[config.hidden_act]
  581. else:
  582. self.transform_act_fn = config.hidden_act
  583. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  584. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  585. hidden_states = self.dense(hidden_states)
  586. hidden_states = self.transform_act_fn(hidden_states)
  587. hidden_states = self.LayerNorm(hidden_states)
  588. return hidden_states
  589. # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Mra
  590. class MraLMPredictionHead(nn.Module):
  591. def __init__(self, config):
  592. super().__init__()
  593. self.transform = MraPredictionHeadTransform(config)
  594. # The output weights are the same as the input embeddings, but there is
  595. # an output-only bias for each token.
  596. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
  597. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  598. def forward(self, hidden_states):
  599. hidden_states = self.transform(hidden_states)
  600. hidden_states = self.decoder(hidden_states)
  601. return hidden_states
  602. # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Mra
  603. class MraOnlyMLMHead(nn.Module):
  604. def __init__(self, config):
  605. super().__init__()
  606. self.predictions = MraLMPredictionHead(config)
  607. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  608. prediction_scores = self.predictions(sequence_output)
  609. return prediction_scores
  610. @auto_docstring
  611. # Copied from transformers.models.yoso.modeling_yoso.YosoPreTrainedModel with Yoso->Mra,yoso->mra
  612. class MraPreTrainedModel(PreTrainedModel):
  613. config: MraConfig
  614. base_model_prefix = "mra"
  615. supports_gradient_checkpointing = True
  616. @torch.no_grad()
  617. def _init_weights(self, module: nn.Module):
  618. """Initialize the weights"""
  619. super()._init_weights(module)
  620. if isinstance(module, MraLMPredictionHead):
  621. init.zeros_(module.bias)
  622. elif isinstance(module, MraEmbeddings):
  623. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)) + 2)
  624. init.zeros_(module.token_type_ids)
  625. @auto_docstring
  626. class MraModel(MraPreTrainedModel):
  627. def __init__(self, config):
  628. super().__init__(config)
  629. self.config = config
  630. self.embeddings = MraEmbeddings(config)
  631. self.encoder = MraEncoder(config)
  632. # Initialize weights and apply final processing
  633. self.post_init()
  634. def get_input_embeddings(self):
  635. return self.embeddings.word_embeddings
  636. def set_input_embeddings(self, value):
  637. self.embeddings.word_embeddings = value
  638. @auto_docstring
  639. def forward(
  640. self,
  641. input_ids: torch.Tensor | None = None,
  642. attention_mask: torch.Tensor | None = None,
  643. token_type_ids: torch.Tensor | None = None,
  644. position_ids: torch.Tensor | None = None,
  645. inputs_embeds: torch.Tensor | None = None,
  646. output_hidden_states: bool | None = None,
  647. return_dict: bool | None = None,
  648. **kwargs,
  649. ) -> tuple | BaseModelOutputWithCrossAttentions:
  650. output_hidden_states = (
  651. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  652. )
  653. return_dict = return_dict if return_dict is not None else self.config.return_dict
  654. if input_ids is not None and inputs_embeds is not None:
  655. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  656. elif input_ids is not None:
  657. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  658. input_shape = input_ids.size()
  659. elif inputs_embeds is not None:
  660. input_shape = inputs_embeds.size()[:-1]
  661. else:
  662. raise ValueError("You have to specify either input_ids or inputs_embeds")
  663. batch_size, seq_length = input_shape
  664. device = input_ids.device if input_ids is not None else inputs_embeds.device
  665. if attention_mask is None:
  666. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  667. if token_type_ids is None:
  668. if hasattr(self.embeddings, "token_type_ids"):
  669. buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
  670. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
  671. token_type_ids = buffered_token_type_ids_expanded
  672. else:
  673. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  674. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  675. # ourselves in which case we just need to make it broadcastable to all heads.
  676. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  677. embedding_output = self.embeddings(
  678. input_ids=input_ids,
  679. position_ids=position_ids,
  680. token_type_ids=token_type_ids,
  681. inputs_embeds=inputs_embeds,
  682. )
  683. encoder_outputs = self.encoder(
  684. embedding_output,
  685. attention_mask=extended_attention_mask,
  686. output_hidden_states=output_hidden_states,
  687. return_dict=return_dict,
  688. )
  689. sequence_output = encoder_outputs[0]
  690. if not return_dict:
  691. return (sequence_output,) + encoder_outputs[1:]
  692. return BaseModelOutputWithCrossAttentions(
  693. last_hidden_state=sequence_output,
  694. hidden_states=encoder_outputs.hidden_states,
  695. attentions=encoder_outputs.attentions,
  696. cross_attentions=encoder_outputs.cross_attentions,
  697. )
  698. @auto_docstring
  699. class MraForMaskedLM(MraPreTrainedModel):
  700. _tied_weights_keys = {
  701. "cls.predictions.decoder.bias": "cls.predictions.bias",
  702. "cls.predictions.decoder.weight": "mra.embeddings.word_embeddings.weight",
  703. }
  704. def __init__(self, config):
  705. super().__init__(config)
  706. self.mra = MraModel(config)
  707. self.cls = MraOnlyMLMHead(config)
  708. # Initialize weights and apply final processing
  709. self.post_init()
  710. def get_output_embeddings(self):
  711. return self.cls.predictions.decoder
  712. def set_output_embeddings(self, new_embeddings):
  713. self.cls.predictions.decoder = new_embeddings
  714. self.cls.predictions.bias = new_embeddings.bias
  715. @auto_docstring
  716. def forward(
  717. self,
  718. input_ids: torch.Tensor | None = None,
  719. attention_mask: torch.Tensor | None = None,
  720. token_type_ids: torch.Tensor | None = None,
  721. position_ids: torch.Tensor | None = None,
  722. inputs_embeds: torch.Tensor | None = None,
  723. labels: torch.Tensor | None = None,
  724. output_hidden_states: bool | None = None,
  725. return_dict: bool | None = None,
  726. **kwargs,
  727. ) -> tuple | MaskedLMOutput:
  728. r"""
  729. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  730. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  731. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  732. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  733. """
  734. return_dict = return_dict if return_dict is not None else self.config.return_dict
  735. outputs = self.mra(
  736. input_ids,
  737. attention_mask=attention_mask,
  738. token_type_ids=token_type_ids,
  739. position_ids=position_ids,
  740. inputs_embeds=inputs_embeds,
  741. output_hidden_states=output_hidden_states,
  742. return_dict=return_dict,
  743. )
  744. sequence_output = outputs[0]
  745. prediction_scores = self.cls(sequence_output)
  746. masked_lm_loss = None
  747. if labels is not None:
  748. loss_fct = CrossEntropyLoss() # -100 index = padding token
  749. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  750. if not return_dict:
  751. output = (prediction_scores,) + outputs[1:]
  752. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  753. return MaskedLMOutput(
  754. loss=masked_lm_loss,
  755. logits=prediction_scores,
  756. hidden_states=outputs.hidden_states,
  757. attentions=outputs.attentions,
  758. )
  759. # Copied from transformers.models.yoso.modeling_yoso.YosoClassificationHead with Yoso->Mra
  760. class MraClassificationHead(nn.Module):
  761. """Head for sentence-level classification tasks."""
  762. def __init__(self, config):
  763. super().__init__()
  764. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  765. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  766. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  767. self.config = config
  768. def forward(self, features, **kwargs):
  769. x = features[:, 0, :] # take <s> token (equiv. to [CLS])
  770. x = self.dropout(x)
  771. x = self.dense(x)
  772. x = ACT2FN[self.config.hidden_act](x)
  773. x = self.dropout(x)
  774. x = self.out_proj(x)
  775. return x
  776. @auto_docstring(
  777. custom_intro="""
  778. MRA Model transformer with a sequence classification/regression head on top (a linear layer on top of
  779. the pooled output) e.g. for GLUE tasks.
  780. """
  781. )
  782. class MraForSequenceClassification(MraPreTrainedModel):
  783. def __init__(self, config):
  784. super().__init__(config)
  785. self.num_labels = config.num_labels
  786. self.mra = MraModel(config)
  787. self.classifier = MraClassificationHead(config)
  788. # Initialize weights and apply final processing
  789. self.post_init()
  790. @auto_docstring
  791. def forward(
  792. self,
  793. input_ids: torch.Tensor | None = None,
  794. attention_mask: torch.Tensor | None = None,
  795. token_type_ids: torch.Tensor | None = None,
  796. position_ids: torch.Tensor | None = None,
  797. inputs_embeds: torch.Tensor | None = None,
  798. labels: torch.Tensor | None = None,
  799. output_hidden_states: bool | None = None,
  800. return_dict: bool | None = None,
  801. **kwargs,
  802. ) -> tuple | SequenceClassifierOutput:
  803. r"""
  804. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  805. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  806. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  807. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  808. """
  809. return_dict = return_dict if return_dict is not None else self.config.return_dict
  810. outputs = self.mra(
  811. input_ids,
  812. attention_mask=attention_mask,
  813. token_type_ids=token_type_ids,
  814. position_ids=position_ids,
  815. inputs_embeds=inputs_embeds,
  816. output_hidden_states=output_hidden_states,
  817. return_dict=return_dict,
  818. )
  819. sequence_output = outputs[0]
  820. logits = self.classifier(sequence_output)
  821. loss = None
  822. if labels is not None:
  823. if self.config.problem_type is None:
  824. if self.num_labels == 1:
  825. self.config.problem_type = "regression"
  826. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  827. self.config.problem_type = "single_label_classification"
  828. else:
  829. self.config.problem_type = "multi_label_classification"
  830. if self.config.problem_type == "regression":
  831. loss_fct = MSELoss()
  832. if self.num_labels == 1:
  833. loss = loss_fct(logits.squeeze(), labels.squeeze())
  834. else:
  835. loss = loss_fct(logits, labels)
  836. elif self.config.problem_type == "single_label_classification":
  837. loss_fct = CrossEntropyLoss()
  838. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  839. elif self.config.problem_type == "multi_label_classification":
  840. loss_fct = BCEWithLogitsLoss()
  841. loss = loss_fct(logits, labels)
  842. if not return_dict:
  843. output = (logits,) + outputs[1:]
  844. return ((loss,) + output) if loss is not None else output
  845. return SequenceClassifierOutput(
  846. loss=loss,
  847. logits=logits,
  848. hidden_states=outputs.hidden_states,
  849. attentions=outputs.attentions,
  850. )
  851. @auto_docstring
  852. class MraForMultipleChoice(MraPreTrainedModel):
  853. def __init__(self, config):
  854. super().__init__(config)
  855. self.mra = MraModel(config)
  856. self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
  857. self.classifier = nn.Linear(config.hidden_size, 1)
  858. # Initialize weights and apply final processing
  859. self.post_init()
  860. @auto_docstring
  861. def forward(
  862. self,
  863. input_ids: torch.Tensor | None = None,
  864. attention_mask: torch.Tensor | None = None,
  865. token_type_ids: torch.Tensor | None = None,
  866. position_ids: torch.Tensor | None = None,
  867. inputs_embeds: torch.Tensor | None = None,
  868. labels: torch.Tensor | None = None,
  869. output_hidden_states: bool | None = None,
  870. return_dict: bool | None = None,
  871. **kwargs,
  872. ) -> tuple | MultipleChoiceModelOutput:
  873. r"""
  874. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  875. Indices of input sequence tokens in the vocabulary.
  876. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  877. [`PreTrainedTokenizer.__call__`] for details.
  878. [What are input IDs?](../glossary#input-ids)
  879. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  880. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  881. 1]`:
  882. - 0 corresponds to a *sentence A* token,
  883. - 1 corresponds to a *sentence B* token.
  884. [What are token type IDs?](../glossary#token-type-ids)
  885. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  886. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  887. config.max_position_embeddings - 1]`.
  888. [What are position IDs?](../glossary#position-ids)
  889. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  890. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  891. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  892. model's internal embedding lookup matrix.
  893. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  894. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  895. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  896. `input_ids` above)
  897. """
  898. return_dict = return_dict if return_dict is not None else self.config.return_dict
  899. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  900. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  901. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  902. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  903. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  904. inputs_embeds = (
  905. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  906. if inputs_embeds is not None
  907. else None
  908. )
  909. outputs = self.mra(
  910. input_ids,
  911. attention_mask=attention_mask,
  912. token_type_ids=token_type_ids,
  913. position_ids=position_ids,
  914. inputs_embeds=inputs_embeds,
  915. output_hidden_states=output_hidden_states,
  916. return_dict=return_dict,
  917. )
  918. hidden_state = outputs[0] # (bs * num_choices, seq_len, dim)
  919. pooled_output = hidden_state[:, 0] # (bs * num_choices, dim)
  920. pooled_output = self.pre_classifier(pooled_output) # (bs * num_choices, dim)
  921. pooled_output = nn.ReLU()(pooled_output) # (bs * num_choices, dim)
  922. logits = self.classifier(pooled_output)
  923. reshaped_logits = logits.view(-1, num_choices)
  924. loss = None
  925. if labels is not None:
  926. loss_fct = CrossEntropyLoss()
  927. loss = loss_fct(reshaped_logits, labels)
  928. if not return_dict:
  929. output = (reshaped_logits,) + outputs[1:]
  930. return ((loss,) + output) if loss is not None else output
  931. return MultipleChoiceModelOutput(
  932. loss=loss,
  933. logits=reshaped_logits,
  934. hidden_states=outputs.hidden_states,
  935. attentions=outputs.attentions,
  936. )
  937. @auto_docstring
  938. class MraForTokenClassification(MraPreTrainedModel):
  939. def __init__(self, config):
  940. super().__init__(config)
  941. self.num_labels = config.num_labels
  942. self.mra = MraModel(config)
  943. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  944. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  945. # Initialize weights and apply final processing
  946. self.post_init()
  947. @auto_docstring
  948. def forward(
  949. self,
  950. input_ids: torch.Tensor | None = None,
  951. attention_mask: torch.Tensor | None = None,
  952. token_type_ids: torch.Tensor | None = None,
  953. position_ids: torch.Tensor | None = None,
  954. inputs_embeds: torch.Tensor | None = None,
  955. labels: torch.Tensor | None = None,
  956. output_hidden_states: bool | None = None,
  957. return_dict: bool | None = None,
  958. **kwargs,
  959. ) -> tuple | TokenClassifierOutput:
  960. r"""
  961. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  962. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  963. """
  964. return_dict = return_dict if return_dict is not None else self.config.return_dict
  965. outputs = self.mra(
  966. input_ids,
  967. attention_mask=attention_mask,
  968. token_type_ids=token_type_ids,
  969. position_ids=position_ids,
  970. inputs_embeds=inputs_embeds,
  971. output_hidden_states=output_hidden_states,
  972. return_dict=return_dict,
  973. )
  974. sequence_output = outputs[0]
  975. sequence_output = self.dropout(sequence_output)
  976. logits = self.classifier(sequence_output)
  977. loss = None
  978. if labels is not None:
  979. loss_fct = CrossEntropyLoss()
  980. # Only keep active parts of the loss
  981. if attention_mask is not None:
  982. active_loss = attention_mask.view(-1) == 1
  983. active_logits = logits.view(-1, self.num_labels)
  984. active_labels = torch.where(
  985. active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
  986. )
  987. loss = loss_fct(active_logits, active_labels)
  988. else:
  989. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  990. if not return_dict:
  991. output = (logits,) + outputs[1:]
  992. return ((loss,) + output) if loss is not None else output
  993. return TokenClassifierOutput(
  994. loss=loss,
  995. logits=logits,
  996. hidden_states=outputs.hidden_states,
  997. attentions=outputs.attentions,
  998. )
  999. @auto_docstring
  1000. class MraForQuestionAnswering(MraPreTrainedModel):
  1001. def __init__(self, config):
  1002. super().__init__(config)
  1003. config.num_labels = 2
  1004. self.num_labels = config.num_labels
  1005. self.mra = MraModel(config)
  1006. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1007. # Initialize weights and apply final processing
  1008. self.post_init()
  1009. @auto_docstring
  1010. def forward(
  1011. self,
  1012. input_ids: torch.Tensor | None = None,
  1013. attention_mask: torch.Tensor | None = None,
  1014. token_type_ids: torch.Tensor | None = None,
  1015. position_ids: torch.Tensor | None = None,
  1016. inputs_embeds: torch.Tensor | None = None,
  1017. start_positions: torch.Tensor | None = None,
  1018. end_positions: torch.Tensor | None = None,
  1019. output_hidden_states: bool | None = None,
  1020. return_dict: bool | None = None,
  1021. **kwargs,
  1022. ) -> tuple | QuestionAnsweringModelOutput:
  1023. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1024. outputs = self.mra(
  1025. input_ids,
  1026. attention_mask=attention_mask,
  1027. token_type_ids=token_type_ids,
  1028. position_ids=position_ids,
  1029. inputs_embeds=inputs_embeds,
  1030. output_hidden_states=output_hidden_states,
  1031. return_dict=return_dict,
  1032. )
  1033. sequence_output = outputs[0]
  1034. logits = self.qa_outputs(sequence_output)
  1035. start_logits, end_logits = logits.split(1, dim=-1)
  1036. start_logits = start_logits.squeeze(-1)
  1037. end_logits = end_logits.squeeze(-1)
  1038. total_loss = None
  1039. if start_positions is not None and end_positions is not None:
  1040. # If we are on multi-GPU, split add a dimension
  1041. if len(start_positions.size()) > 1:
  1042. start_positions = start_positions.squeeze(-1)
  1043. if len(end_positions.size()) > 1:
  1044. end_positions = end_positions.squeeze(-1)
  1045. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1046. ignored_index = start_logits.size(1)
  1047. start_positions = start_positions.clamp(0, ignored_index)
  1048. end_positions = end_positions.clamp(0, ignored_index)
  1049. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1050. start_loss = loss_fct(start_logits, start_positions)
  1051. end_loss = loss_fct(end_logits, end_positions)
  1052. total_loss = (start_loss + end_loss) / 2
  1053. if not return_dict:
  1054. output = (start_logits, end_logits) + outputs[1:]
  1055. return ((total_loss,) + output) if total_loss is not None else output
  1056. return QuestionAnsweringModelOutput(
  1057. loss=total_loss,
  1058. start_logits=start_logits,
  1059. end_logits=end_logits,
  1060. hidden_states=outputs.hidden_states,
  1061. attentions=outputs.attentions,
  1062. )
# Names exported by a star-import of this module; this list is the module's declared public API.
__all__ = [
    "MraForMaskedLM",
    "MraForMultipleChoice",
    "MraForQuestionAnswering",
    "MraForSequenceClassification",
    "MraForTokenClassification",
    "MraLayer",
    "MraModel",
    "MraPreTrainedModel",
]