modeling_eurobert.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/eurobert/modular_eurobert.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_eurobert.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Nicolas Boizard, Duarte M. Alves, Hippolyte Gisserot-Boukhlef and the EuroBert team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from typing import Optional
  23. import torch
  24. from torch import nn
  25. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache
  28. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  29. from ...masking_utils import create_bidirectional_mask
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
  32. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import auto_docstring
  36. from ...utils.generic import TransformersKwargs, can_return_tuple, maybe_autocast, merge_with_config_defaults
  37. from ...utils.output_capturing import capture_outputs
  38. from .configuration_eurobert import EuroBertConfig
  39. @use_kernel_forward_from_hub("RMSNorm")
  40. class EuroBertRMSNorm(nn.Module):
  41. def __init__(self, hidden_size, eps=1e-5) -> None:
  42. """
  43. EuroBertRMSNorm is equivalent to T5LayerNorm
  44. """
  45. super().__init__()
  46. self.weight = nn.Parameter(torch.ones(hidden_size))
  47. self.variance_epsilon = eps
  48. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  49. input_dtype = hidden_states.dtype
  50. hidden_states = hidden_states.to(torch.float32)
  51. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  52. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  53. return self.weight * hidden_states.to(input_dtype)
  54. def extra_repr(self):
  55. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  56. def rotate_half(x):
  57. """Rotates half the hidden dims of the input."""
  58. x1 = x[..., : x.shape[-1] // 2]
  59. x2 = x[..., x.shape[-1] // 2 :]
  60. return torch.cat((-x2, x1), dim=-1)
  61. @use_kernel_func_from_hub("rotary_pos_emb")
  62. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  63. """Applies Rotary Position Embedding to the query and key tensors.
  64. Args:
  65. q (`torch.Tensor`): The query tensor.
  66. k (`torch.Tensor`): The key tensor.
  67. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  68. sin (`torch.Tensor`): The sine part of the rotary embedding.
  69. unsqueeze_dim (`int`, *optional*, defaults to 1):
  70. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  71. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  72. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  73. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  74. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  75. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  76. Returns:
  77. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  78. """
  79. cos = cos.unsqueeze(unsqueeze_dim)
  80. sin = sin.unsqueeze(unsqueeze_dim)
  81. q_embed = (q * cos) + (rotate_half(q) * sin)
  82. k_embed = (k * cos) + (rotate_half(k) * sin)
  83. return q_embed, k_embed
  84. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  85. """
  86. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  87. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  88. """
  89. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  90. if n_rep == 1:
  91. return hidden_states
  92. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  93. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  94. def eager_attention_forward(
  95. module: nn.Module,
  96. query: torch.Tensor,
  97. key: torch.Tensor,
  98. value: torch.Tensor,
  99. attention_mask: torch.Tensor | None,
  100. scaling: float,
  101. dropout: float = 0.0,
  102. **kwargs: Unpack[TransformersKwargs],
  103. ):
  104. key_states = repeat_kv(key, module.num_key_value_groups)
  105. value_states = repeat_kv(value, module.num_key_value_groups)
  106. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  107. if attention_mask is not None:
  108. attn_weights = attn_weights + attention_mask
  109. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  110. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  111. attn_output = torch.matmul(attn_weights, value_states)
  112. attn_output = attn_output.transpose(1, 2).contiguous()
  113. return attn_output, attn_weights
  114. @use_kernelized_func(apply_rotary_pos_emb)
  115. class EuroBertAttention(nn.Module):
  116. """Multi-headed attention from 'Attention Is All You Need' paper"""
  117. def __init__(self, config: EuroBertConfig, layer_idx: int):
  118. super().__init__()
  119. self.config = config
  120. self.layer_idx = layer_idx
  121. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  122. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  123. self.scaling = self.head_dim**-0.5
  124. self.attention_dropout = config.attention_dropout
  125. self.is_causal = False
  126. self.q_proj = nn.Linear(
  127. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  128. )
  129. self.k_proj = nn.Linear(
  130. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  131. )
  132. self.v_proj = nn.Linear(
  133. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  134. )
  135. self.o_proj = nn.Linear(
  136. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  137. )
  138. def forward(
  139. self,
  140. hidden_states: torch.Tensor,
  141. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  142. attention_mask: torch.Tensor | None = None,
  143. past_key_values: Cache | None = None,
  144. **kwargs: Unpack[TransformersKwargs],
  145. ) -> tuple[torch.Tensor, torch.Tensor]:
  146. input_shape = hidden_states.shape[:-1]
  147. hidden_shape = (*input_shape, -1, self.head_dim)
  148. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  149. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  150. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  151. cos, sin = position_embeddings
  152. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  153. if past_key_values is not None:
  154. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  155. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  156. self.config._attn_implementation, eager_attention_forward
  157. )
  158. attn_output, attn_weights = attention_interface(
  159. self,
  160. query_states,
  161. key_states,
  162. value_states,
  163. attention_mask,
  164. dropout=0.0 if not self.training else self.attention_dropout,
  165. scaling=self.scaling,
  166. **kwargs,
  167. )
  168. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  169. attn_output = self.o_proj(attn_output)
  170. return attn_output, attn_weights
  171. class EuroBertMLP(nn.Module):
  172. def __init__(self, config):
  173. super().__init__()
  174. self.config = config
  175. self.hidden_size = config.hidden_size
  176. self.intermediate_size = config.intermediate_size
  177. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  178. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  179. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  180. self.act_fn = ACT2FN[config.hidden_act]
  181. def forward(self, x):
  182. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  183. return down_proj
  184. class EuroBertDecoderLayer(GradientCheckpointingLayer):
  185. def __init__(self, config: EuroBertConfig, layer_idx: int):
  186. super().__init__()
  187. self.hidden_size = config.hidden_size
  188. self.self_attn = EuroBertAttention(config=config, layer_idx=layer_idx)
  189. self.mlp = EuroBertMLP(config)
  190. self.input_layernorm = EuroBertRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  191. self.post_attention_layernorm = EuroBertRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  192. def forward(
  193. self,
  194. hidden_states: torch.Tensor,
  195. attention_mask: torch.Tensor | None = None,
  196. position_ids: torch.LongTensor | None = None,
  197. past_key_values: Cache | None = None,
  198. use_cache: bool | None = False,
  199. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  200. **kwargs: Unpack[TransformersKwargs],
  201. ) -> torch.Tensor:
  202. residual = hidden_states
  203. hidden_states = self.input_layernorm(hidden_states)
  204. # Self Attention
  205. hidden_states, _ = self.self_attn(
  206. hidden_states=hidden_states,
  207. attention_mask=attention_mask,
  208. position_ids=position_ids,
  209. past_key_values=past_key_values,
  210. use_cache=use_cache,
  211. position_embeddings=position_embeddings,
  212. **kwargs,
  213. )
  214. hidden_states = residual + hidden_states
  215. # Fully Connected
  216. residual = hidden_states
  217. hidden_states = self.post_attention_layernorm(hidden_states)
  218. hidden_states = self.mlp(hidden_states)
  219. hidden_states = residual + hidden_states
  220. return hidden_states
  221. @auto_docstring
  222. class EuroBertPreTrainedModel(PreTrainedModel):
  223. config: EuroBertConfig
  224. base_model_prefix = "model"
  225. supports_gradient_checkpointing = True
  226. _no_split_modules = ["EuroBertDecoderLayer"]
  227. _skip_keys_device_placement = ["past_key_values"]
  228. _supports_flash_attn = True
  229. _supports_sdpa = True
  230. _supports_flex_attn = True
  231. _can_compile_fullgraph = True
  232. _supports_attention_backend = True
  233. _can_record_outputs = {
  234. "hidden_states": EuroBertDecoderLayer,
  235. "attentions": EuroBertAttention,
  236. }
  237. class EuroBertRotaryEmbedding(nn.Module):
  238. inv_freq: torch.Tensor # fix linting for `register_buffer`
  239. def __init__(self, config: EuroBertConfig, device=None):
  240. super().__init__()
  241. self.max_seq_len_cached = config.max_position_embeddings
  242. self.original_max_seq_len = config.max_position_embeddings
  243. self.config = config
  244. self.rope_type = self.config.rope_parameters["rope_type"]
  245. rope_init_fn: Callable = self.compute_default_rope_parameters
  246. if self.rope_type != "default":
  247. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  248. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  249. self.register_buffer("inv_freq", inv_freq, persistent=False)
  250. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  251. @staticmethod
  252. def compute_default_rope_parameters(
  253. config: EuroBertConfig | None = None,
  254. device: Optional["torch.device"] = None,
  255. seq_len: int | None = None,
  256. ) -> tuple["torch.Tensor", float]:
  257. """
  258. Computes the inverse frequencies according to the original RoPE implementation
  259. Args:
  260. config ([`~transformers.PreTrainedConfig`]):
  261. The model configuration.
  262. device (`torch.device`):
  263. The device to use for initialization of the inverse frequencies.
  264. seq_len (`int`, *optional*):
  265. The current sequence length. Unused for this type of RoPE.
  266. Returns:
  267. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  268. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  269. """
  270. base = config.rope_parameters["rope_theta"]
  271. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  272. attention_factor = 1.0 # Unused in this type of RoPE
  273. # Compute the inverse frequencies
  274. inv_freq = 1.0 / (
  275. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  276. )
  277. return inv_freq, attention_factor
  278. @torch.no_grad()
  279. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  280. def forward(self, x, position_ids):
  281. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  282. position_ids_expanded = position_ids[:, None, :].float()
  283. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  284. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  285. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  286. emb = torch.cat((freqs, freqs), dim=-1)
  287. cos = emb.cos() * self.attention_scaling
  288. sin = emb.sin() * self.attention_scaling
  289. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  290. @auto_docstring
  291. class EuroBertModel(EuroBertPreTrainedModel):
  292. def __init__(self, config: EuroBertConfig):
  293. super().__init__(config)
  294. self.padding_idx = config.pad_token_id
  295. self.vocab_size = config.vocab_size
  296. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  297. self.layers = nn.ModuleList(
  298. [EuroBertDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  299. )
  300. self.norm = EuroBertRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  301. self.rotary_emb = EuroBertRotaryEmbedding(config=config)
  302. self.gradient_checkpointing = False
  303. # Initialize weights and apply final processing
  304. self.post_init()
  305. @merge_with_config_defaults
  306. @capture_outputs
  307. @auto_docstring
  308. def forward(
  309. self,
  310. input_ids: torch.LongTensor = None,
  311. attention_mask: torch.Tensor | None = None,
  312. position_ids: torch.LongTensor | None = None,
  313. inputs_embeds: torch.FloatTensor | None = None,
  314. **kwargs: Unpack[TransformersKwargs],
  315. ) -> tuple | BaseModelOutput:
  316. if (input_ids is None) ^ (inputs_embeds is not None):
  317. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  318. if inputs_embeds is None:
  319. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  320. if position_ids is None:
  321. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
  322. bidirectional_mask = create_bidirectional_mask(
  323. config=self.config,
  324. inputs_embeds=inputs_embeds,
  325. attention_mask=attention_mask,
  326. )
  327. hidden_states = inputs_embeds
  328. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  329. for encoder_layer in self.layers[: self.config.num_hidden_layers]:
  330. hidden_states = encoder_layer(
  331. hidden_states,
  332. attention_mask=bidirectional_mask,
  333. position_embeddings=position_embeddings,
  334. position_ids=position_ids,
  335. **kwargs,
  336. )
  337. hidden_states = self.norm(hidden_states)
  338. return BaseModelOutput(
  339. last_hidden_state=hidden_states,
  340. )
  341. @auto_docstring
  342. class EuroBertForMaskedLM(EuroBertPreTrainedModel):
  343. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  344. _tp_plan = {"lm_head": "colwise_gather_output"}
  345. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  346. def __init__(self, config: EuroBertConfig):
  347. super().__init__(config)
  348. self.model = EuroBertModel(config)
  349. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, config.mlp_bias)
  350. # Initialize weights and apply final processing
  351. self.post_init()
  352. @can_return_tuple
  353. @auto_docstring
  354. def forward(
  355. self,
  356. input_ids: torch.LongTensor | None = None,
  357. attention_mask: torch.Tensor | None = None,
  358. position_ids: torch.LongTensor | None = None,
  359. inputs_embeds: torch.FloatTensor | None = None,
  360. labels: torch.LongTensor | None = None,
  361. **kwargs: Unpack[TransformersKwargs],
  362. ) -> tuple[torch.Tensor] | MaskedLMOutput:
  363. r"""
  364. Example:
  365. ```python
  366. >>> from transformers import AutoTokenizer, EuroBertForMaskedLM
  367. >>> model = EuroBertForMaskedLM.from_pretrained("EuroBERT/EuroBERT-210m")
  368. >>> tokenizer = AutoTokenizer.from_pretrained("EuroBERT/EuroBERT-210m")
  369. >>> text = "The capital of France is <|mask|>."
  370. >>> inputs = tokenizer(text, return_tensors="pt")
  371. >>> outputs = model(**inputs)
  372. >>> # To get predictions for the mask:
  373. >>> masked_index = inputs["input_ids"][0].tolist().index(tokenizer.mask_token_id)
  374. >>> predicted_token_id = outputs.logits[0, masked_index].argmax(axis=-1)
  375. >>> predicted_token = tokenizer.decode(predicted_token_id)
  376. >>> print("Predicted token:", predicted_token)
  377. Predicted token: Paris
  378. ```"""
  379. outputs: BaseModelOutput = self.model(
  380. input_ids=input_ids,
  381. attention_mask=attention_mask,
  382. position_ids=position_ids,
  383. inputs_embeds=inputs_embeds,
  384. **kwargs,
  385. )
  386. logits = self.lm_head(outputs.last_hidden_state)
  387. loss = None
  388. if labels is not None:
  389. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  390. return MaskedLMOutput(
  391. loss=loss,
  392. logits=logits,
  393. hidden_states=outputs.hidden_states,
  394. attentions=outputs.attentions,
  395. )
  396. @auto_docstring
  397. class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
  398. def __init__(self, config: EuroBertConfig):
  399. super().__init__(config)
  400. self.num_labels = config.num_labels
  401. self.classifier_pooling = config.classifier_pooling
  402. self.model = EuroBertModel(config)
  403. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  404. self.activation = nn.GELU()
  405. self.classifier = nn.Linear(config.hidden_size, self.num_labels)
  406. self.post_init()
  407. @can_return_tuple
  408. @auto_docstring
  409. def forward(
  410. self,
  411. input_ids: torch.LongTensor | None = None,
  412. attention_mask: torch.Tensor | None = None,
  413. position_ids: torch.LongTensor | None = None,
  414. inputs_embeds: torch.FloatTensor | None = None,
  415. labels: torch.LongTensor | None = None,
  416. **kwargs: Unpack[TransformersKwargs],
  417. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  418. encoder_output = self.model(
  419. input_ids,
  420. attention_mask=attention_mask,
  421. position_ids=position_ids,
  422. inputs_embeds=inputs_embeds,
  423. **kwargs,
  424. )
  425. last_hidden_state = encoder_output[0]
  426. if self.classifier_pooling in ["bos", "mean"]:
  427. if self.classifier_pooling == "bos":
  428. pooled_output = last_hidden_state[:, 0]
  429. elif self.classifier_pooling == "mean":
  430. if attention_mask is None:
  431. pooled_output = last_hidden_state.mean(dim=1)
  432. else:
  433. attention_mask = attention_mask.to(last_hidden_state.device)
  434. pooled_output = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1)
  435. pooled_output /= attention_mask.sum(dim=1, keepdim=True)
  436. pooled_output = self.dense(pooled_output)
  437. pooled_output = self.activation(pooled_output)
  438. logits = self.classifier(pooled_output)
  439. elif self.classifier_pooling == "late":
  440. x = self.dense(last_hidden_state)
  441. x = self.activation(x)
  442. logits = self.classifier(x)
  443. if attention_mask is None:
  444. logits = logits.mean(dim=1)
  445. else:
  446. attention_mask = attention_mask.to(logits.device)
  447. logits = (logits * attention_mask.unsqueeze(-1)).sum(dim=1)
  448. logits /= attention_mask.sum(dim=1, keepdim=True)
  449. loss = None
  450. if labels is not None:
  451. labels = labels.to(logits.device)
  452. if self.config.problem_type is None:
  453. if self.num_labels == 1:
  454. self.config.problem_type = "regression"
  455. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  456. self.config.problem_type = "single_label_classification"
  457. else:
  458. self.config.problem_type = "multi_label_classification"
  459. if self.config.problem_type == "regression":
  460. loss_fct = MSELoss()
  461. if self.num_labels == 1:
  462. loss = loss_fct(logits.squeeze(), labels.squeeze())
  463. else:
  464. loss = loss_fct(logits, labels)
  465. elif self.config.problem_type == "single_label_classification":
  466. loss_fct = CrossEntropyLoss()
  467. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  468. elif self.config.problem_type == "multi_label_classification":
  469. loss_fct = BCEWithLogitsLoss()
  470. loss = loss_fct(logits, labels)
  471. return SequenceClassifierOutput(
  472. loss=loss,
  473. logits=logits,
  474. hidden_states=encoder_output.hidden_states,
  475. attentions=encoder_output.attentions,
  476. )
  477. @auto_docstring
  478. class EuroBertForTokenClassification(EuroBertPreTrainedModel):
  479. def __init__(self, config: EuroBertConfig):
  480. super().__init__(config)
  481. self.num_labels = config.num_labels
  482. self.model = EuroBertModel(config)
  483. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  484. self.post_init()
  485. def get_input_embeddings(self):
  486. return self.model.embed_tokens
  487. def set_input_embeddings(self, value):
  488. self.model.embed_tokens = value
  489. @can_return_tuple
  490. @auto_docstring
  491. def forward(
  492. self,
  493. input_ids: torch.LongTensor | None = None,
  494. attention_mask: torch.Tensor | None = None,
  495. position_ids: torch.LongTensor | None = None,
  496. inputs_embeds: torch.FloatTensor | None = None,
  497. labels: torch.LongTensor | None = None,
  498. **kwargs: Unpack[TransformersKwargs],
  499. ) -> tuple | TokenClassifierOutput:
  500. r"""
  501. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  502. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  503. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  504. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  505. """
  506. outputs = self.model(
  507. input_ids,
  508. attention_mask=attention_mask,
  509. position_ids=position_ids,
  510. inputs_embeds=inputs_embeds,
  511. **kwargs,
  512. )
  513. sequence_output = outputs[0]
  514. logits = self.classifier(sequence_output)
  515. loss = None
  516. if labels is not None:
  517. loss_fct = CrossEntropyLoss()
  518. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  519. return TokenClassifierOutput(
  520. loss=loss,
  521. logits=logits,
  522. hidden_states=outputs.hidden_states,
  523. attentions=outputs.attentions,
  524. )
  525. __all__ = [
  526. "EuroBertPreTrainedModel",
  527. "EuroBertModel",
  528. "EuroBertForMaskedLM",
  529. "EuroBertForSequenceClassification",
  530. "EuroBertForTokenClassification",
  531. ]