modeling_funnel.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364
  1. # Copyright 2020-present Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Funnel Transformer model."""
  15. from dataclasses import dataclass
  16. import numpy as np
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...modeling_outputs import (
  23. BaseModelOutput,
  24. MaskedLMOutput,
  25. MultipleChoiceModelOutput,
  26. QuestionAnsweringModelOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import PreTrainedModel
  31. from ...utils import ModelOutput, auto_docstring, logging
  32. from .configuration_funnel import FunnelConfig
  33. logger = logging.get_logger(__name__)
  34. INF = 1e6
  35. class FunnelEmbeddings(nn.Module):
  36. def __init__(self, config: FunnelConfig) -> None:
  37. super().__init__()
  38. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  39. self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
  40. self.dropout = nn.Dropout(config.hidden_dropout)
  41. def forward(
  42. self, input_ids: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None
  43. ) -> torch.Tensor:
  44. if inputs_embeds is None:
  45. inputs_embeds = self.word_embeddings(input_ids)
  46. embeddings = self.layer_norm(inputs_embeds)
  47. embeddings = self.dropout(embeddings)
  48. return embeddings
  49. class FunnelAttentionStructure(nn.Module):
  50. """
  51. Contains helpers for `FunnelRelMultiheadAttention `.
  52. """
  53. cls_token_type_id: int = 2
  54. def __init__(self, config: FunnelConfig) -> None:
  55. super().__init__()
  56. self.config = config
  57. self.sin_dropout = nn.Dropout(config.hidden_dropout)
  58. self.cos_dropout = nn.Dropout(config.hidden_dropout)
  59. # Track where we are at in terms of pooling from the original input, e.g., by how much the sequence length was
  60. # divided.
  61. self.pooling_mult = None
  62. def init_attention_inputs(
  63. self,
  64. inputs_embeds: torch.Tensor,
  65. attention_mask: torch.Tensor | None = None,
  66. token_type_ids: torch.Tensor | None = None,
  67. ) -> tuple[torch.Tensor]:
  68. """Returns the attention inputs associated to the inputs of the model."""
  69. # inputs_embeds has shape batch_size x seq_len x d_model
  70. # attention_mask and token_type_ids have shape batch_size x seq_len
  71. self.pooling_mult = 1
  72. self.seq_len = seq_len = inputs_embeds.size(1)
  73. position_embeds = self.get_position_embeds(seq_len, inputs_embeds.dtype, inputs_embeds.device)
  74. token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
  75. cls_mask = (
  76. nn.functional.pad(inputs_embeds.new_ones([seq_len - 1, seq_len - 1]), (1, 0, 1, 0))
  77. if self.config.separate_cls
  78. else None
  79. )
  80. return (position_embeds, token_type_mat, attention_mask, cls_mask)
  81. def token_type_ids_to_mat(self, token_type_ids: torch.Tensor) -> torch.Tensor:
  82. """Convert `token_type_ids` to `token_type_mat`."""
  83. token_type_mat = token_type_ids[:, :, None] == token_type_ids[:, None]
  84. # Treat <cls> as in the same segment as both A & B
  85. cls_ids = token_type_ids == self.cls_token_type_id
  86. cls_mat = cls_ids[:, :, None] | cls_ids[:, None]
  87. return cls_mat | token_type_mat
  88. def get_position_embeds(
  89. self, seq_len: int, dtype: torch.dtype, device: torch.device
  90. ) -> tuple[torch.Tensor] | list[list[torch.Tensor]]:
  91. """
  92. Create and cache inputs related to relative position encoding. Those are very different depending on whether we
  93. are using the factorized or the relative shift attention:
  94. For the factorized attention, it returns the matrices (phi, pi, psi, omega) used in the paper, appendix A.2.2,
  95. final formula.
  96. For the relative shift attention, it returns all possible vectors R used in the paper, appendix A.2.1, final
  97. formula.
  98. Paper link: https://huggingface.co/papers/2006.03236
  99. """
  100. d_model = self.config.d_model
  101. if self.config.attention_type == "factorized":
  102. # Notations from the paper, appending A.2.2, final formula.
  103. # We need to create and return the matrices phi, psi, pi and omega.
  104. pos_seq = torch.arange(0, seq_len, 1.0, dtype=torch.int64, device=device).to(dtype)
  105. freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=torch.int64, device=device).to(dtype)
  106. inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
  107. sinusoid = pos_seq[:, None] * inv_freq[None]
  108. sin_embed = torch.sin(sinusoid)
  109. sin_embed_d = self.sin_dropout(sin_embed)
  110. cos_embed = torch.cos(sinusoid)
  111. cos_embed_d = self.cos_dropout(cos_embed)
  112. # This is different from the formula on the paper...
  113. phi = torch.cat([sin_embed_d, sin_embed_d], dim=-1)
  114. psi = torch.cat([cos_embed, sin_embed], dim=-1)
  115. pi = torch.cat([cos_embed_d, cos_embed_d], dim=-1)
  116. omega = torch.cat([-sin_embed, cos_embed], dim=-1)
  117. return (phi, pi, psi, omega)
  118. else:
  119. # Notations from the paper, appending A.2.1, final formula.
  120. # We need to create and return all the possible vectors R for all blocks and shifts.
  121. freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=torch.int64, device=device).to(dtype)
  122. inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
  123. # Maximum relative positions for the first input
  124. rel_pos_id = torch.arange(-seq_len * 2, seq_len * 2, 1.0, dtype=torch.int64, device=device).to(dtype)
  125. zero_offset = seq_len * 2
  126. sinusoid = rel_pos_id[:, None] * inv_freq[None]
  127. sin_embed = self.sin_dropout(torch.sin(sinusoid))
  128. cos_embed = self.cos_dropout(torch.cos(sinusoid))
  129. pos_embed = torch.cat([sin_embed, cos_embed], dim=-1)
  130. pos = torch.arange(0, seq_len, dtype=torch.int64, device=device).to(dtype)
  131. pooled_pos = pos
  132. position_embeds_list = []
  133. for block_index in range(0, self.config.num_blocks):
  134. # For each block with block_index > 0, we need two types position embeddings:
  135. # - Attention(pooled-q, unpooled-kv)
  136. # - Attention(pooled-q, pooled-kv)
  137. # For block_index = 0 we only need the second one and leave the first one as None.
  138. # First type
  139. if block_index == 0:
  140. position_embeds_pooling = None
  141. else:
  142. pooled_pos = self.stride_pool_pos(pos, block_index)
  143. # construct rel_pos_id
  144. stride = 2 ** (block_index - 1)
  145. rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2)
  146. rel_pos = rel_pos[:, None] + zero_offset
  147. rel_pos = rel_pos.expand(rel_pos.size(0), d_model)
  148. position_embeds_pooling = torch.gather(pos_embed, 0, rel_pos)
  149. # Second type
  150. pos = pooled_pos
  151. stride = 2**block_index
  152. rel_pos = self.relative_pos(pos, stride)
  153. rel_pos = rel_pos[:, None] + zero_offset
  154. rel_pos = rel_pos.expand(rel_pos.size(0), d_model)
  155. position_embeds_no_pooling = torch.gather(pos_embed, 0, rel_pos)
  156. position_embeds_list.append([position_embeds_no_pooling, position_embeds_pooling])
  157. return position_embeds_list
  158. def stride_pool_pos(self, pos_id: torch.Tensor, block_index: int):
  159. """
  160. Pool `pos_id` while keeping the cls token separate (if `config.separate_cls=True`).
  161. """
  162. if self.config.separate_cls:
  163. # Under separate <cls>, we treat the <cls> as the first token in
  164. # the previous block of the 1st real block. Since the 1st real
  165. # block always has position 1, the position of the previous block
  166. # will be at `1 - 2 ** block_index`.
  167. cls_pos = pos_id.new_tensor([-(2**block_index) + 1])
  168. pooled_pos_id = pos_id[1:-1] if self.config.truncate_seq else pos_id[1:]
  169. return torch.cat([cls_pos, pooled_pos_id[::2]], 0)
  170. else:
  171. return pos_id[::2]
  172. def relative_pos(self, pos: torch.Tensor, stride: int, pooled_pos=None, shift: int = 1) -> torch.Tensor:
  173. """
  174. Build the relative positional vector between `pos` and `pooled_pos`.
  175. """
  176. if pooled_pos is None:
  177. pooled_pos = pos
  178. ref_point = pooled_pos[0] - pos[0]
  179. num_remove = shift * len(pooled_pos)
  180. max_dist = ref_point + num_remove * stride
  181. min_dist = pooled_pos[0] - pos[-1]
  182. return torch.arange(max_dist, min_dist - 1, -stride, dtype=torch.long, device=pos.device)
  183. def stride_pool(
  184. self,
  185. tensor: torch.Tensor | tuple[torch.Tensor] | list[torch.Tensor],
  186. axis: int | tuple[int] | list[int],
  187. ) -> torch.Tensor:
  188. """
  189. Perform pooling by stride slicing the tensor along the given axis.
  190. """
  191. if tensor is None:
  192. return None
  193. # Do the stride pool recursively if axis is a list or a tuple of ints.
  194. if isinstance(axis, (list, tuple)):
  195. for ax in axis:
  196. tensor = self.stride_pool(tensor, ax)
  197. return tensor
  198. # Do the stride pool recursively if tensor is a list or tuple of tensors.
  199. if isinstance(tensor, (tuple, list)):
  200. return type(tensor)(self.stride_pool(x, axis) for x in tensor)
  201. # Deal with negative axis
  202. axis %= tensor.ndim
  203. axis_slice = (
  204. slice(None, -1, 2) if self.config.separate_cls and self.config.truncate_seq else slice(None, None, 2)
  205. )
  206. enc_slice = tuple([slice(None)] * axis + [axis_slice])
  207. if self.config.separate_cls:
  208. cls_slice = tuple([slice(None)] * axis + [slice(None, 1)])
  209. tensor = torch.cat([tensor[cls_slice], tensor], axis=axis)
  210. return tensor[enc_slice]
  211. def pool_tensor(
  212. self, tensor: torch.Tensor | tuple[torch.Tensor] | list[torch.Tensor], mode: str = "mean", stride: int = 2
  213. ) -> torch.Tensor:
  214. """Apply 1D pooling to a tensor of size [B x T (x H)]."""
  215. if tensor is None:
  216. return None
  217. # Do the pool recursively if tensor is a list or tuple of tensors.
  218. if isinstance(tensor, (tuple, list)):
  219. return type(tensor)(self.pool_tensor(tensor, mode=mode, stride=stride) for x in tensor)
  220. if self.config.separate_cls:
  221. suffix = tensor[:, :-1] if self.config.truncate_seq else tensor
  222. tensor = torch.cat([tensor[:, :1], suffix], dim=1)
  223. ndim = tensor.ndim
  224. if ndim == 2:
  225. tensor = tensor[:, None, :, None]
  226. elif ndim == 3:
  227. tensor = tensor[:, None, :, :]
  228. # Stride is applied on the second-to-last dimension.
  229. stride = (stride, 1)
  230. if mode == "mean":
  231. tensor = nn.functional.avg_pool2d(tensor, stride, stride=stride, ceil_mode=True)
  232. elif mode == "max":
  233. tensor = nn.functional.max_pool2d(tensor, stride, stride=stride, ceil_mode=True)
  234. elif mode == "min":
  235. tensor = -nn.functional.max_pool2d(-tensor, stride, stride=stride, ceil_mode=True)
  236. else:
  237. raise NotImplementedError("The supported modes are 'mean', 'max' and 'min'.")
  238. if ndim == 2:
  239. return tensor[:, 0, :, 0]
  240. elif ndim == 3:
  241. return tensor[:, 0]
  242. return tensor
  243. def pre_attention_pooling(
  244. self, output, attention_inputs: tuple[torch.Tensor]
  245. ) -> tuple[torch.Tensor, tuple[torch.Tensor]]:
  246. """Pool `output` and the proper parts of `attention_inputs` before the attention layer."""
  247. position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
  248. if self.config.pool_q_only:
  249. if self.config.attention_type == "factorized":
  250. position_embeds = self.stride_pool(position_embeds[:2], 0) + position_embeds[2:]
  251. token_type_mat = self.stride_pool(token_type_mat, 1)
  252. cls_mask = self.stride_pool(cls_mask, 0)
  253. output = self.pool_tensor(output, mode=self.config.pooling_type)
  254. else:
  255. self.pooling_mult *= 2
  256. if self.config.attention_type == "factorized":
  257. position_embeds = self.stride_pool(position_embeds, 0)
  258. token_type_mat = self.stride_pool(token_type_mat, [1, 2])
  259. cls_mask = self.stride_pool(cls_mask, [1, 2])
  260. attention_mask = self.pool_tensor(attention_mask, mode="min")
  261. output = self.pool_tensor(output, mode=self.config.pooling_type)
  262. attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
  263. return output, attention_inputs
  264. def post_attention_pooling(self, attention_inputs: tuple[torch.Tensor]) -> tuple[torch.Tensor]:
  265. """Pool the proper parts of `attention_inputs` after the attention layer."""
  266. position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
  267. if self.config.pool_q_only:
  268. self.pooling_mult *= 2
  269. if self.config.attention_type == "factorized":
  270. position_embeds = position_embeds[:2] + self.stride_pool(position_embeds[2:], 0)
  271. token_type_mat = self.stride_pool(token_type_mat, 2)
  272. cls_mask = self.stride_pool(cls_mask, 1)
  273. attention_mask = self.pool_tensor(attention_mask, mode="min")
  274. attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
  275. return attention_inputs
  276. def _relative_shift_gather(positional_attn: torch.Tensor, context_len: int, shift: int) -> torch.Tensor:
  277. batch_size, n_head, seq_len, max_rel_len = positional_attn.shape
  278. # max_rel_len = 2 * context_len + shift -1 is the numbers of possible relative positions i-j
  279. # What's next is the same as doing the following gather, which might be clearer code but less efficient.
  280. # idxs = context_len + torch.arange(0, context_len).unsqueeze(0) - torch.arange(0, seq_len).unsqueeze(1)
  281. # # matrix of context_len + i-j
  282. # return positional_attn.gather(3, idxs.expand([batch_size, n_head, context_len, context_len]))
  283. positional_attn = torch.reshape(positional_attn, [batch_size, n_head, max_rel_len, seq_len])
  284. positional_attn = positional_attn[:, :, shift:, :]
  285. positional_attn = torch.reshape(positional_attn, [batch_size, n_head, seq_len, max_rel_len - shift])
  286. positional_attn = positional_attn[..., :context_len]
  287. return positional_attn
  288. class FunnelRelMultiheadAttention(nn.Module):
  289. def __init__(self, config: FunnelConfig, block_index: int) -> None:
  290. super().__init__()
  291. self.config = config
  292. self.block_index = block_index
  293. d_model, n_head, d_head = config.d_model, config.n_head, config.d_head
  294. self.hidden_dropout = nn.Dropout(config.hidden_dropout)
  295. self.attention_dropout = nn.Dropout(config.attention_dropout)
  296. self.q_head = nn.Linear(d_model, n_head * d_head, bias=False)
  297. self.k_head = nn.Linear(d_model, n_head * d_head)
  298. self.v_head = nn.Linear(d_model, n_head * d_head)
  299. self.r_w_bias = nn.Parameter(torch.zeros([n_head, d_head]))
  300. self.r_r_bias = nn.Parameter(torch.zeros([n_head, d_head]))
  301. self.r_kernel = nn.Parameter(torch.zeros([d_model, n_head, d_head]))
  302. self.r_s_bias = nn.Parameter(torch.zeros([n_head, d_head]))
  303. self.seg_embed = nn.Parameter(torch.zeros([2, n_head, d_head]))
  304. self.post_proj = nn.Linear(n_head * d_head, d_model)
  305. self.layer_norm = nn.LayerNorm(d_model, eps=config.layer_norm_eps)
  306. self.scale = 1.0 / (d_head**0.5)
  307. def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None):
  308. """Relative attention score for the positional encodings"""
  309. # q_head has shape batch_size x sea_len x n_head x d_head
  310. if self.config.attention_type == "factorized":
  311. # Notations from the paper, appending A.2.2, final formula (https://huggingface.co/papers/2006.03236)
  312. # phi and pi have shape seq_len x d_model, psi and omega have shape context_len x d_model
  313. phi, pi, psi, omega = position_embeds
  314. # Shape n_head x d_head
  315. u = self.r_r_bias * self.scale
  316. # Shape d_model x n_head x d_head
  317. w_r = self.r_kernel
  318. # Shape batch_size x sea_len x n_head x d_model
  319. q_r_attention = torch.einsum("binh,dnh->bind", q_head + u, w_r)
  320. q_r_attention_1 = q_r_attention * phi[:, None]
  321. q_r_attention_2 = q_r_attention * pi[:, None]
  322. # Shape batch_size x n_head x seq_len x context_len
  323. positional_attn = torch.einsum("bind,jd->bnij", q_r_attention_1, psi) + torch.einsum(
  324. "bind,jd->bnij", q_r_attention_2, omega
  325. )
  326. else:
  327. shift = 2 if q_head.shape[1] != context_len else 1
  328. # Notations from the paper, appending A.2.1, final formula (https://huggingface.co/papers/2006.03236)
  329. # Grab the proper positional encoding, shape max_rel_len x d_model
  330. r = position_embeds[self.block_index][shift - 1]
  331. # Shape n_head x d_head
  332. v = self.r_r_bias * self.scale
  333. # Shape d_model x n_head x d_head
  334. w_r = self.r_kernel
  335. # Shape max_rel_len x n_head x d_model
  336. r_head = torch.einsum("td,dnh->tnh", r, w_r)
  337. # Shape batch_size x n_head x seq_len x max_rel_len
  338. positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
  339. # Shape batch_size x n_head x seq_len x context_len
  340. positional_attn = _relative_shift_gather(positional_attn, context_len, shift)
  341. if cls_mask is not None:
  342. positional_attn *= cls_mask
  343. return positional_attn
  344. def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None):
  345. """Relative attention score for the token_type_ids"""
  346. if token_type_mat is None:
  347. return 0
  348. batch_size, seq_len, context_len = token_type_mat.shape
  349. # q_head has shape batch_size x seq_len x n_head x d_head
  350. # Shape n_head x d_head
  351. r_s_bias = self.r_s_bias * self.scale
  352. # Shape batch_size x n_head x seq_len x 2
  353. token_type_bias = torch.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
  354. # Shape batch_size x n_head x seq_len x context_len
  355. token_type_mat = token_type_mat[:, None].expand([batch_size, q_head.shape[2], seq_len, context_len])
  356. # Shapes batch_size x n_head x seq_len
  357. diff_token_type, same_token_type = torch.split(token_type_bias, 1, dim=-1)
  358. # Shape batch_size x n_head x seq_len x context_len
  359. token_type_attn = torch.where(
  360. token_type_mat, same_token_type.expand(token_type_mat.shape), diff_token_type.expand(token_type_mat.shape)
  361. )
  362. if cls_mask is not None:
  363. token_type_attn *= cls_mask
  364. return token_type_attn
  365. def forward(
  366. self,
  367. query: torch.Tensor,
  368. key: torch.Tensor,
  369. value: torch.Tensor,
  370. attention_inputs: tuple[torch.Tensor],
  371. output_attentions: bool = False,
  372. ) -> tuple[torch.Tensor, ...]:
  373. # query has shape batch_size x seq_len x d_model
  374. # key and value have shapes batch_size x context_len x d_model
  375. position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
  376. batch_size, seq_len, _ = query.shape
  377. context_len = key.shape[1]
  378. n_head, d_head = self.config.n_head, self.config.d_head
  379. # Shape batch_size x seq_len x n_head x d_head
  380. q_head = self.q_head(query).view(batch_size, seq_len, n_head, d_head)
  381. # Shapes batch_size x context_len x n_head x d_head
  382. k_head = self.k_head(key).view(batch_size, context_len, n_head, d_head)
  383. v_head = self.v_head(value).view(batch_size, context_len, n_head, d_head)
  384. q_head = q_head * self.scale
  385. # Shape n_head x d_head
  386. r_w_bias = self.r_w_bias * self.scale
  387. # Shapes batch_size x n_head x seq_len x context_len
  388. content_score = torch.einsum("bind,bjnd->bnij", q_head + r_w_bias, k_head)
  389. positional_attn = self.relative_positional_attention(position_embeds, q_head, context_len, cls_mask)
  390. token_type_attn = self.relative_token_type_attention(token_type_mat, q_head, cls_mask)
  391. # merge attention scores
  392. attn_score = content_score + positional_attn + token_type_attn
  393. # precision safe in case of mixed precision training
  394. dtype = attn_score.dtype
  395. attn_score = attn_score.float()
  396. # perform masking
  397. if attention_mask is not None:
  398. attn_score = attn_score - INF * (1 - attention_mask[:, None, None].float())
  399. # attention probability
  400. attn_prob = torch.softmax(attn_score, dim=-1, dtype=dtype)
  401. attn_prob = self.attention_dropout(attn_prob)
  402. # attention output, shape batch_size x seq_len x n_head x d_head
  403. attn_vec = torch.einsum("bnij,bjnd->bind", attn_prob, v_head)
  404. # Shape shape batch_size x seq_len x d_model
  405. attn_out = self.post_proj(attn_vec.reshape(batch_size, seq_len, n_head * d_head))
  406. attn_out = self.hidden_dropout(attn_out)
  407. output = self.layer_norm(query + attn_out)
  408. return (output, attn_prob) if output_attentions else (output,)
  409. class FunnelPositionwiseFFN(nn.Module):
  410. def __init__(self, config: FunnelConfig) -> None:
  411. super().__init__()
  412. self.linear_1 = nn.Linear(config.d_model, config.d_inner)
  413. self.activation_function = ACT2FN[config.hidden_act]
  414. self.activation_dropout = nn.Dropout(config.activation_dropout)
  415. self.linear_2 = nn.Linear(config.d_inner, config.d_model)
  416. self.dropout = nn.Dropout(config.hidden_dropout)
  417. self.layer_norm = nn.LayerNorm(config.d_model, config.layer_norm_eps)
  418. def forward(self, hidden: torch.Tensor) -> torch.Tensor:
  419. h = self.linear_1(hidden)
  420. h = self.activation_function(h)
  421. h = self.activation_dropout(h)
  422. h = self.linear_2(h)
  423. h = self.dropout(h)
  424. return self.layer_norm(hidden + h)
  425. class FunnelLayer(nn.Module):
  426. def __init__(self, config: FunnelConfig, block_index: int) -> None:
  427. super().__init__()
  428. self.attention = FunnelRelMultiheadAttention(config, block_index)
  429. self.ffn = FunnelPositionwiseFFN(config)
  430. def forward(
  431. self,
  432. query: torch.Tensor,
  433. key: torch.Tensor,
  434. value: torch.Tensor,
  435. attention_inputs,
  436. output_attentions: bool = False,
  437. ) -> tuple:
  438. attn = self.attention(query, key, value, attention_inputs, output_attentions=output_attentions)
  439. output = self.ffn(attn[0])
  440. return (output, attn[1]) if output_attentions else (output,)
  441. class FunnelEncoder(nn.Module):
  442. def __init__(self, config: FunnelConfig) -> None:
  443. super().__init__()
  444. self.config = config
  445. self.attention_structure = FunnelAttentionStructure(config)
  446. self.blocks = nn.ModuleList(
  447. [
  448. nn.ModuleList([FunnelLayer(config, block_index) for _ in range(block_size)])
  449. for block_index, block_size in enumerate(config.block_sizes)
  450. ]
  451. )
  452. def forward(
  453. self,
  454. inputs_embeds: torch.Tensor,
  455. attention_mask: torch.Tensor | None = None,
  456. token_type_ids: torch.Tensor | None = None,
  457. output_attentions: bool = False,
  458. output_hidden_states: bool = False,
  459. return_dict: bool = True,
  460. ) -> tuple | BaseModelOutput:
  461. # The pooling is not implemented on long tensors, so we convert this mask.
  462. attention_mask = attention_mask.type_as(inputs_embeds)
  463. attention_inputs = self.attention_structure.init_attention_inputs(
  464. inputs_embeds,
  465. attention_mask=attention_mask,
  466. token_type_ids=token_type_ids,
  467. )
  468. hidden = inputs_embeds
  469. all_hidden_states = (inputs_embeds,) if output_hidden_states else None
  470. all_attentions = () if output_attentions else None
  471. for block_index, block in enumerate(self.blocks):
  472. pooling_flag = hidden.size(1) > (2 if self.config.separate_cls else 1)
  473. pooling_flag = pooling_flag and block_index > 0
  474. if pooling_flag:
  475. pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
  476. hidden, attention_inputs
  477. )
  478. for layer_index, layer in enumerate(block):
  479. for repeat_index in range(self.config.block_repeats[block_index]):
  480. do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
  481. if do_pooling:
  482. query = pooled_hidden
  483. key = value = hidden if self.config.pool_q_only else pooled_hidden
  484. else:
  485. query = key = value = hidden
  486. layer_output = layer(query, key, value, attention_inputs, output_attentions=output_attentions)
  487. hidden = layer_output[0]
  488. if do_pooling:
  489. attention_inputs = self.attention_structure.post_attention_pooling(attention_inputs)
  490. if output_attentions:
  491. all_attentions = all_attentions + layer_output[1:]
  492. if output_hidden_states:
  493. all_hidden_states = all_hidden_states + (hidden,)
  494. if not return_dict:
  495. return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None)
  496. return BaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions)
  497. def upsample(
  498. x: torch.Tensor, stride: int, target_len: int, separate_cls: bool = True, truncate_seq: bool = False
  499. ) -> torch.Tensor:
  500. """
  501. Upsample tensor `x` to match `target_len` by repeating the tokens `stride` time on the sequence length dimension.
  502. """
  503. if stride == 1:
  504. return x
  505. if separate_cls:
  506. cls = x[:, :1]
  507. x = x[:, 1:]
  508. output = torch.repeat_interleave(x, repeats=stride, dim=1)
  509. if separate_cls:
  510. if truncate_seq:
  511. output = nn.functional.pad(output, (0, 0, 0, stride - 1, 0, 0))
  512. output = output[:, : target_len - 1]
  513. output = torch.cat([cls, output], dim=1)
  514. else:
  515. output = output[:, :target_len]
  516. return output
  517. class FunnelDecoder(nn.Module):
  518. def __init__(self, config: FunnelConfig) -> None:
  519. super().__init__()
  520. self.config = config
  521. self.attention_structure = FunnelAttentionStructure(config)
  522. self.layers = nn.ModuleList([FunnelLayer(config, 0) for _ in range(config.num_decoder_layers)])
  523. def forward(
  524. self,
  525. final_hidden: torch.Tensor,
  526. first_block_hidden: torch.Tensor,
  527. attention_mask: torch.Tensor | None = None,
  528. token_type_ids: torch.Tensor | None = None,
  529. output_attentions: bool = False,
  530. output_hidden_states: bool = False,
  531. return_dict: bool = True,
  532. ) -> tuple | BaseModelOutput:
  533. upsampled_hidden = upsample(
  534. final_hidden,
  535. stride=2 ** (len(self.config.block_sizes) - 1),
  536. target_len=first_block_hidden.shape[1],
  537. separate_cls=self.config.separate_cls,
  538. truncate_seq=self.config.truncate_seq,
  539. )
  540. hidden = upsampled_hidden + first_block_hidden
  541. all_hidden_states = (hidden,) if output_hidden_states else None
  542. all_attentions = () if output_attentions else None
  543. attention_inputs = self.attention_structure.init_attention_inputs(
  544. hidden,
  545. attention_mask=attention_mask,
  546. token_type_ids=token_type_ids,
  547. )
  548. for layer in self.layers:
  549. layer_output = layer(hidden, hidden, hidden, attention_inputs, output_attentions=output_attentions)
  550. hidden = layer_output[0]
  551. if output_attentions:
  552. all_attentions = all_attentions + layer_output[1:]
  553. if output_hidden_states:
  554. all_hidden_states = all_hidden_states + (hidden,)
  555. if not return_dict:
  556. return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None)
  557. return BaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions)
  558. class FunnelDiscriminatorPredictions(nn.Module):
  559. """Prediction module for the discriminator, made up of two dense layers."""
  560. def __init__(self, config: FunnelConfig) -> None:
  561. super().__init__()
  562. self.config = config
  563. self.dense = nn.Linear(config.d_model, config.d_model)
  564. self.dense_prediction = nn.Linear(config.d_model, 1)
  565. def forward(self, discriminator_hidden_states: torch.Tensor) -> torch.Tensor:
  566. hidden_states = self.dense(discriminator_hidden_states)
  567. hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
  568. logits = self.dense_prediction(hidden_states).squeeze(-1)
  569. return logits
  570. @auto_docstring
  571. class FunnelPreTrainedModel(PreTrainedModel):
  572. config: FunnelConfig
  573. base_model_prefix = "funnel"
  574. @torch.no_grad()
  575. def _init_weights(self, module):
  576. classname = module.__class__.__name__
  577. if classname.find("Linear") != -1:
  578. if getattr(module, "weight", None) is not None:
  579. if self.config.initializer_std is None:
  580. fan_out, fan_in = module.weight.shape
  581. std = np.sqrt(1.0 / float(fan_in + fan_out))
  582. else:
  583. std = self.config.initializer_std
  584. init.normal_(module.weight, std=std)
  585. if getattr(module, "bias", None) is not None:
  586. init.constant_(module.bias, 0.0)
  587. elif classname == "FunnelRelMultiheadAttention":
  588. init.uniform_(module.r_w_bias, b=self.config.initializer_range)
  589. init.uniform_(module.r_r_bias, b=self.config.initializer_range)
  590. init.uniform_(module.r_kernel, b=self.config.initializer_range)
  591. init.uniform_(module.r_s_bias, b=self.config.initializer_range)
  592. init.uniform_(module.seg_embed, b=self.config.initializer_range)
  593. elif classname == "FunnelEmbeddings":
  594. std = 1.0 if self.config.initializer_std is None else self.config.initializer_std
  595. init.normal_(module.word_embeddings.weight, std=std)
  596. if module.word_embeddings.padding_idx is not None:
  597. init.zeros_(module.word_embeddings.weight[module.word_embeddings.padding_idx])
  598. class FunnelClassificationHead(nn.Module):
  599. def __init__(self, config: FunnelConfig, n_labels: int) -> None:
  600. super().__init__()
  601. self.linear_hidden = nn.Linear(config.d_model, config.d_model)
  602. self.dropout = nn.Dropout(config.hidden_dropout)
  603. self.linear_out = nn.Linear(config.d_model, n_labels)
  604. def forward(self, hidden: torch.Tensor) -> torch.Tensor:
  605. hidden = self.linear_hidden(hidden)
  606. hidden = torch.tanh(hidden)
  607. hidden = self.dropout(hidden)
  608. return self.linear_out(hidden)
  609. @dataclass
  610. @auto_docstring(
  611. custom_intro="""
  612. Output type of [`FunnelForPreTraining`].
  613. """
  614. )
  615. class FunnelForPreTrainingOutput(ModelOutput):
  616. r"""
  617. loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  618. Total loss of the ELECTRA-style objective.
  619. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  620. Prediction scores of the head (scores for each token before SoftMax).
  621. """
  622. loss: torch.FloatTensor | None = None
  623. logits: torch.FloatTensor | None = None
  624. hidden_states: tuple[torch.FloatTensor] | None = None
  625. attentions: tuple[torch.FloatTensor] | None = None
  626. @auto_docstring(
  627. custom_intro="""
  628. The base Funnel Transformer Model transformer outputting raw hidden-states without upsampling head (also called
  629. decoder) or any task-specific head on top.
  630. """
  631. )
  632. class FunnelBaseModel(FunnelPreTrainedModel):
  633. def __init__(self, config: FunnelConfig) -> None:
  634. super().__init__(config)
  635. self.embeddings = FunnelEmbeddings(config)
  636. self.encoder = FunnelEncoder(config)
  637. # Initialize weights and apply final processing
  638. self.post_init()
  639. def get_input_embeddings(self) -> nn.Embedding:
  640. return self.embeddings.word_embeddings
  641. def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
  642. self.embeddings.word_embeddings = new_embeddings
  643. @auto_docstring
  644. def forward(
  645. self,
  646. input_ids: torch.Tensor | None = None,
  647. attention_mask: torch.Tensor | None = None,
  648. token_type_ids: torch.Tensor | None = None,
  649. position_ids: torch.Tensor | None = None,
  650. inputs_embeds: torch.Tensor | None = None,
  651. output_attentions: bool | None = None,
  652. output_hidden_states: bool | None = None,
  653. return_dict: bool | None = None,
  654. **kwargs,
  655. ) -> tuple | BaseModelOutput:
  656. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  657. output_hidden_states = (
  658. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  659. )
  660. return_dict = return_dict if return_dict is not None else self.config.return_dict
  661. if input_ids is not None and inputs_embeds is not None:
  662. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  663. elif input_ids is not None:
  664. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  665. input_shape = input_ids.size()
  666. elif inputs_embeds is not None:
  667. input_shape = inputs_embeds.size()[:-1]
  668. else:
  669. raise ValueError("You have to specify either input_ids or inputs_embeds")
  670. device = input_ids.device if input_ids is not None else inputs_embeds.device
  671. if attention_mask is None:
  672. attention_mask = torch.ones(input_shape, device=device)
  673. if token_type_ids is None:
  674. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  675. inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)
  676. encoder_outputs = self.encoder(
  677. inputs_embeds,
  678. attention_mask=attention_mask,
  679. token_type_ids=token_type_ids,
  680. output_attentions=output_attentions,
  681. output_hidden_states=output_hidden_states,
  682. return_dict=return_dict,
  683. )
  684. return encoder_outputs
  685. @auto_docstring
  686. class FunnelModel(FunnelPreTrainedModel):
  687. def __init__(self, config: FunnelConfig) -> None:
  688. super().__init__(config)
  689. self.config = config
  690. self.embeddings = FunnelEmbeddings(config)
  691. self.encoder = FunnelEncoder(config)
  692. self.decoder = FunnelDecoder(config)
  693. # Initialize weights and apply final processing
  694. self.post_init()
  695. def get_input_embeddings(self) -> nn.Embedding:
  696. return self.embeddings.word_embeddings
  697. def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
  698. self.embeddings.word_embeddings = new_embeddings
  699. @auto_docstring
  700. def forward(
  701. self,
  702. input_ids: torch.Tensor | None = None,
  703. attention_mask: torch.Tensor | None = None,
  704. token_type_ids: torch.Tensor | None = None,
  705. inputs_embeds: torch.Tensor | None = None,
  706. output_attentions: bool | None = None,
  707. output_hidden_states: bool | None = None,
  708. return_dict: bool | None = None,
  709. **kwargs,
  710. ) -> tuple | BaseModelOutput:
  711. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  712. output_hidden_states = (
  713. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  714. )
  715. return_dict = return_dict if return_dict is not None else self.config.return_dict
  716. if input_ids is not None and inputs_embeds is not None:
  717. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  718. elif input_ids is not None:
  719. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  720. input_shape = input_ids.size()
  721. elif inputs_embeds is not None:
  722. input_shape = inputs_embeds.size()[:-1]
  723. else:
  724. raise ValueError("You have to specify either input_ids or inputs_embeds")
  725. device = input_ids.device if input_ids is not None else inputs_embeds.device
  726. if attention_mask is None:
  727. attention_mask = torch.ones(input_shape, device=device)
  728. if token_type_ids is None:
  729. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  730. inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)
  731. encoder_outputs = self.encoder(
  732. inputs_embeds,
  733. attention_mask=attention_mask,
  734. token_type_ids=token_type_ids,
  735. output_attentions=output_attentions,
  736. output_hidden_states=True,
  737. return_dict=return_dict,
  738. )
  739. decoder_outputs = self.decoder(
  740. final_hidden=encoder_outputs[0],
  741. first_block_hidden=encoder_outputs[1][self.config.block_sizes[0]],
  742. attention_mask=attention_mask,
  743. token_type_ids=token_type_ids,
  744. output_attentions=output_attentions,
  745. output_hidden_states=output_hidden_states,
  746. return_dict=return_dict,
  747. )
  748. if not return_dict:
  749. idx = 0
  750. outputs = (decoder_outputs[0],)
  751. if output_hidden_states:
  752. idx += 1
  753. outputs = outputs + (encoder_outputs[1] + decoder_outputs[idx],)
  754. if output_attentions:
  755. idx += 1
  756. outputs = outputs + (encoder_outputs[2] + decoder_outputs[idx],)
  757. return outputs
  758. return BaseModelOutput(
  759. last_hidden_state=decoder_outputs[0],
  760. hidden_states=(encoder_outputs.hidden_states + decoder_outputs.hidden_states)
  761. if output_hidden_states
  762. else None,
  763. attentions=(encoder_outputs.attentions + decoder_outputs.attentions) if output_attentions else None,
  764. )
  765. @auto_docstring(
  766. custom_intro="""
  767. Funnel Transformer model with a binary classification head on top as used during pretraining for identifying
  768. generated tokens.
  769. """
  770. )
  771. class FunnelForPreTraining(FunnelPreTrainedModel):
  772. def __init__(self, config: FunnelConfig) -> None:
  773. super().__init__(config)
  774. self.funnel = FunnelModel(config)
  775. self.discriminator_predictions = FunnelDiscriminatorPredictions(config)
  776. # Initialize weights and apply final processing
  777. self.post_init()
  778. @auto_docstring
  779. def forward(
  780. self,
  781. input_ids: torch.Tensor | None = None,
  782. attention_mask: torch.Tensor | None = None,
  783. token_type_ids: torch.Tensor | None = None,
  784. inputs_embeds: torch.Tensor | None = None,
  785. labels: torch.Tensor | None = None,
  786. output_attentions: bool | None = None,
  787. output_hidden_states: bool | None = None,
  788. return_dict: bool | None = None,
  789. **kwargs,
  790. ) -> tuple | FunnelForPreTrainingOutput:
  791. r"""
  792. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  793. Labels for computing the ELECTRA-style loss. Input should be a sequence of tokens (see `input_ids`
  794. docstring) Indices should be in `[0, 1]`:
  795. - 0 indicates the token is an original token,
  796. - 1 indicates the token was replaced.
  797. Examples:
  798. ```python
  799. >>> from transformers import AutoTokenizer, FunnelForPreTraining
  800. >>> import torch
  801. >>> tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small")
  802. >>> model = FunnelForPreTraining.from_pretrained("funnel-transformer/small")
  803. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  804. >>> logits = model(**inputs).logits
  805. ```"""
  806. return_dict = return_dict if return_dict is not None else self.config.return_dict
  807. discriminator_hidden_states = self.funnel(
  808. input_ids,
  809. attention_mask=attention_mask,
  810. token_type_ids=token_type_ids,
  811. inputs_embeds=inputs_embeds,
  812. output_attentions=output_attentions,
  813. output_hidden_states=output_hidden_states,
  814. return_dict=return_dict,
  815. )
  816. discriminator_sequence_output = discriminator_hidden_states[0]
  817. logits = self.discriminator_predictions(discriminator_sequence_output)
  818. loss = None
  819. if labels is not None:
  820. loss_fct = nn.BCEWithLogitsLoss()
  821. if attention_mask is not None:
  822. active_loss = attention_mask.view(-1, discriminator_sequence_output.shape[1]) == 1
  823. active_logits = logits.view(-1, discriminator_sequence_output.shape[1])[active_loss]
  824. active_labels = labels[active_loss]
  825. loss = loss_fct(active_logits, active_labels.float())
  826. else:
  827. loss = loss_fct(logits.view(-1, discriminator_sequence_output.shape[1]), labels.float())
  828. if not return_dict:
  829. output = (logits,) + discriminator_hidden_states[1:]
  830. return ((loss,) + output) if loss is not None else output
  831. return FunnelForPreTrainingOutput(
  832. loss=loss,
  833. logits=logits,
  834. hidden_states=discriminator_hidden_states.hidden_states,
  835. attentions=discriminator_hidden_states.attentions,
  836. )
  837. @auto_docstring
  838. class FunnelForMaskedLM(FunnelPreTrainedModel):
  839. _tied_weights_keys = {"lm_head.weight": "funnel.embeddings.word_embeddings.weight"}
  840. def __init__(self, config: FunnelConfig) -> None:
  841. super().__init__(config)
  842. self.funnel = FunnelModel(config)
  843. self.lm_head = nn.Linear(config.d_model, config.vocab_size)
  844. # Initialize weights and apply final processing
  845. self.post_init()
  846. def get_output_embeddings(self) -> nn.Linear:
  847. return self.lm_head
  848. def set_output_embeddings(self, new_embeddings: nn.Embedding) -> None:
  849. self.lm_head = new_embeddings
  850. @auto_docstring
  851. def forward(
  852. self,
  853. input_ids: torch.Tensor | None = None,
  854. attention_mask: torch.Tensor | None = None,
  855. token_type_ids: torch.Tensor | None = None,
  856. inputs_embeds: torch.Tensor | None = None,
  857. labels: torch.Tensor | None = None,
  858. output_attentions: bool | None = None,
  859. output_hidden_states: bool | None = None,
  860. return_dict: bool | None = None,
  861. **kwargs,
  862. ) -> tuple | MaskedLMOutput:
  863. r"""
  864. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  865. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  866. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  867. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  868. """
  869. return_dict = return_dict if return_dict is not None else self.config.return_dict
  870. outputs = self.funnel(
  871. input_ids,
  872. attention_mask=attention_mask,
  873. token_type_ids=token_type_ids,
  874. inputs_embeds=inputs_embeds,
  875. output_attentions=output_attentions,
  876. output_hidden_states=output_hidden_states,
  877. return_dict=return_dict,
  878. )
  879. last_hidden_state = outputs[0]
  880. prediction_logits = self.lm_head(last_hidden_state)
  881. masked_lm_loss = None
  882. if labels is not None:
  883. loss_fct = CrossEntropyLoss() # -100 index = padding token
  884. masked_lm_loss = loss_fct(prediction_logits.view(-1, self.config.vocab_size), labels.view(-1))
  885. if not return_dict:
  886. output = (prediction_logits,) + outputs[1:]
  887. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  888. return MaskedLMOutput(
  889. loss=masked_lm_loss,
  890. logits=prediction_logits,
  891. hidden_states=outputs.hidden_states,
  892. attentions=outputs.attentions,
  893. )
  894. @auto_docstring(
  895. custom_intro="""
  896. Funnel Transformer Model with a sequence classification/regression head on top (two linear layer on top of the
  897. first timestep of the last hidden state) e.g. for GLUE tasks.
  898. """
  899. )
  900. class FunnelForSequenceClassification(FunnelPreTrainedModel):
  901. def __init__(self, config: FunnelConfig) -> None:
  902. super().__init__(config)
  903. self.num_labels = config.num_labels
  904. self.config = config
  905. self.funnel = FunnelBaseModel(config)
  906. self.classifier = FunnelClassificationHead(config, config.num_labels)
  907. # Initialize weights and apply final processing
  908. self.post_init()
  909. @auto_docstring
  910. def forward(
  911. self,
  912. input_ids: torch.Tensor | None = None,
  913. attention_mask: torch.Tensor | None = None,
  914. token_type_ids: torch.Tensor | None = None,
  915. inputs_embeds: torch.Tensor | None = None,
  916. labels: torch.Tensor | None = None,
  917. output_attentions: bool | None = None,
  918. output_hidden_states: bool | None = None,
  919. return_dict: bool | None = None,
  920. **kwargs,
  921. ) -> tuple | SequenceClassifierOutput:
  922. r"""
  923. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  924. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  925. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  926. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  927. """
  928. return_dict = return_dict if return_dict is not None else self.config.return_dict
  929. outputs = self.funnel(
  930. input_ids,
  931. attention_mask=attention_mask,
  932. token_type_ids=token_type_ids,
  933. inputs_embeds=inputs_embeds,
  934. output_attentions=output_attentions,
  935. output_hidden_states=output_hidden_states,
  936. return_dict=return_dict,
  937. )
  938. last_hidden_state = outputs[0]
  939. pooled_output = last_hidden_state[:, 0]
  940. logits = self.classifier(pooled_output)
  941. loss = None
  942. if labels is not None:
  943. if self.config.problem_type is None:
  944. if self.num_labels == 1:
  945. self.config.problem_type = "regression"
  946. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  947. self.config.problem_type = "single_label_classification"
  948. else:
  949. self.config.problem_type = "multi_label_classification"
  950. if self.config.problem_type == "regression":
  951. loss_fct = MSELoss()
  952. if self.num_labels == 1:
  953. loss = loss_fct(logits.squeeze(), labels.squeeze())
  954. else:
  955. loss = loss_fct(logits, labels)
  956. elif self.config.problem_type == "single_label_classification":
  957. loss_fct = CrossEntropyLoss()
  958. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  959. elif self.config.problem_type == "multi_label_classification":
  960. loss_fct = BCEWithLogitsLoss()
  961. loss = loss_fct(logits, labels)
  962. if not return_dict:
  963. output = (logits,) + outputs[1:]
  964. return ((loss,) + output) if loss is not None else output
  965. return SequenceClassifierOutput(
  966. loss=loss,
  967. logits=logits,
  968. hidden_states=outputs.hidden_states,
  969. attentions=outputs.attentions,
  970. )
  971. @auto_docstring
  972. class FunnelForMultipleChoice(FunnelPreTrainedModel):
  973. def __init__(self, config: FunnelConfig) -> None:
  974. super().__init__(config)
  975. self.funnel = FunnelBaseModel(config)
  976. self.classifier = FunnelClassificationHead(config, 1)
  977. # Initialize weights and apply final processing
  978. self.post_init()
  979. @auto_docstring
  980. def forward(
  981. self,
  982. input_ids: torch.Tensor | None = None,
  983. attention_mask: torch.Tensor | None = None,
  984. token_type_ids: torch.Tensor | None = None,
  985. inputs_embeds: torch.Tensor | None = None,
  986. labels: torch.Tensor | None = None,
  987. output_attentions: bool | None = None,
  988. output_hidden_states: bool | None = None,
  989. return_dict: bool | None = None,
  990. **kwargs,
  991. ) -> tuple | MultipleChoiceModelOutput:
  992. r"""
  993. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  994. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  995. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  996. `input_ids` above)
  997. """
  998. return_dict = return_dict if return_dict is not None else self.config.return_dict
  999. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1000. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1001. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1002. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1003. inputs_embeds = (
  1004. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1005. if inputs_embeds is not None
  1006. else None
  1007. )
  1008. outputs = self.funnel(
  1009. input_ids,
  1010. attention_mask=attention_mask,
  1011. token_type_ids=token_type_ids,
  1012. inputs_embeds=inputs_embeds,
  1013. output_attentions=output_attentions,
  1014. output_hidden_states=output_hidden_states,
  1015. return_dict=return_dict,
  1016. )
  1017. last_hidden_state = outputs[0]
  1018. pooled_output = last_hidden_state[:, 0]
  1019. logits = self.classifier(pooled_output)
  1020. reshaped_logits = logits.view(-1, num_choices)
  1021. loss = None
  1022. if labels is not None:
  1023. loss_fct = CrossEntropyLoss()
  1024. loss = loss_fct(reshaped_logits, labels)
  1025. if not return_dict:
  1026. output = (reshaped_logits,) + outputs[1:]
  1027. return ((loss,) + output) if loss is not None else output
  1028. return MultipleChoiceModelOutput(
  1029. loss=loss,
  1030. logits=reshaped_logits,
  1031. hidden_states=outputs.hidden_states,
  1032. attentions=outputs.attentions,
  1033. )
  1034. @auto_docstring
  1035. class FunnelForTokenClassification(FunnelPreTrainedModel):
  1036. def __init__(self, config: FunnelConfig) -> None:
  1037. super().__init__(config)
  1038. self.num_labels = config.num_labels
  1039. self.funnel = FunnelModel(config)
  1040. self.dropout = nn.Dropout(config.hidden_dropout)
  1041. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1042. # Initialize weights and apply final processing
  1043. self.post_init()
  1044. @auto_docstring
  1045. def forward(
  1046. self,
  1047. input_ids: torch.Tensor | None = None,
  1048. attention_mask: torch.Tensor | None = None,
  1049. token_type_ids: torch.Tensor | None = None,
  1050. inputs_embeds: torch.Tensor | None = None,
  1051. labels: torch.Tensor | None = None,
  1052. output_attentions: bool | None = None,
  1053. output_hidden_states: bool | None = None,
  1054. return_dict: bool | None = None,
  1055. **kwargs,
  1056. ) -> tuple | TokenClassifierOutput:
  1057. r"""
  1058. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1059. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1060. """
  1061. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1062. outputs = self.funnel(
  1063. input_ids,
  1064. attention_mask=attention_mask,
  1065. token_type_ids=token_type_ids,
  1066. inputs_embeds=inputs_embeds,
  1067. output_attentions=output_attentions,
  1068. output_hidden_states=output_hidden_states,
  1069. return_dict=return_dict,
  1070. )
  1071. last_hidden_state = outputs[0]
  1072. last_hidden_state = self.dropout(last_hidden_state)
  1073. logits = self.classifier(last_hidden_state)
  1074. loss = None
  1075. if labels is not None:
  1076. loss_fct = CrossEntropyLoss()
  1077. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1078. if not return_dict:
  1079. output = (logits,) + outputs[1:]
  1080. return ((loss,) + output) if loss is not None else output
  1081. return TokenClassifierOutput(
  1082. loss=loss,
  1083. logits=logits,
  1084. hidden_states=outputs.hidden_states,
  1085. attentions=outputs.attentions,
  1086. )
  1087. @auto_docstring
  1088. class FunnelForQuestionAnswering(FunnelPreTrainedModel):
  1089. def __init__(self, config: FunnelConfig) -> None:
  1090. super().__init__(config)
  1091. self.num_labels = config.num_labels
  1092. self.funnel = FunnelModel(config)
  1093. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1094. # Initialize weights and apply final processing
  1095. self.post_init()
  1096. @auto_docstring
  1097. def forward(
  1098. self,
  1099. input_ids: torch.Tensor | None = None,
  1100. attention_mask: torch.Tensor | None = None,
  1101. token_type_ids: torch.Tensor | None = None,
  1102. inputs_embeds: torch.Tensor | None = None,
  1103. start_positions: torch.Tensor | None = None,
  1104. end_positions: torch.Tensor | None = None,
  1105. output_attentions: bool | None = None,
  1106. output_hidden_states: bool | None = None,
  1107. return_dict: bool | None = None,
  1108. **kwargs,
  1109. ) -> tuple | QuestionAnsweringModelOutput:
  1110. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1111. outputs = self.funnel(
  1112. input_ids,
  1113. attention_mask=attention_mask,
  1114. token_type_ids=token_type_ids,
  1115. inputs_embeds=inputs_embeds,
  1116. output_attentions=output_attentions,
  1117. output_hidden_states=output_hidden_states,
  1118. return_dict=return_dict,
  1119. )
  1120. last_hidden_state = outputs[0]
  1121. logits = self.qa_outputs(last_hidden_state)
  1122. start_logits, end_logits = logits.split(1, dim=-1)
  1123. start_logits = start_logits.squeeze(-1).contiguous()
  1124. end_logits = end_logits.squeeze(-1).contiguous()
  1125. total_loss = None
  1126. if start_positions is not None and end_positions is not None:
  1127. # If we are on multi-GPU, split add a dimension
  1128. if len(start_positions.size()) > 1:
  1129. start_positions = start_positions.squeze(-1)
  1130. if len(end_positions.size()) > 1:
  1131. end_positions = end_positions.squeeze(-1)
  1132. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1133. ignored_index = start_logits.size(1)
  1134. start_positions = start_positions.clamp(0, ignored_index)
  1135. end_positions = end_positions.clamp(0, ignored_index)
  1136. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1137. start_loss = loss_fct(start_logits, start_positions)
  1138. end_loss = loss_fct(end_logits, end_positions)
  1139. total_loss = (start_loss + end_loss) / 2
  1140. if not return_dict:
  1141. output = (start_logits, end_logits) + outputs[1:]
  1142. return ((total_loss,) + output) if total_loss is not None else output
  1143. return QuestionAnsweringModelOutput(
  1144. loss=total_loss,
  1145. start_logits=start_logits,
  1146. end_logits=end_logits,
  1147. hidden_states=outputs.hidden_states,
  1148. attentions=outputs.attentions,
  1149. )
  1150. __all__ = [
  1151. "FunnelBaseModel",
  1152. "FunnelForMaskedLM",
  1153. "FunnelForMultipleChoice",
  1154. "FunnelForPreTraining",
  1155. "FunnelForQuestionAnswering",
  1156. "FunnelForSequenceClassification",
  1157. "FunnelForTokenClassification",
  1158. "FunnelModel",
  1159. "FunnelPreTrainedModel",
  1160. ]