modular_eurobert.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. # Copyright 2025 Nicolas Boizard, Duarte M. Alves, Hippolyte Gisserot-Boukhlef and the EuroBert team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import torch
  16. from torch import nn
  17. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  18. from ...configuration_utils import strict
  19. from ...masking_utils import create_bidirectional_mask
  20. from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
  21. from ...modeling_rope_utils import RopeParameters
  22. from ...processing_utils import Unpack
  23. from ...utils import auto_docstring
  24. from ...utils.generic import TransformersKwargs, can_return_tuple
  25. from ..llama import LlamaConfig
  26. from ..llama.modeling_llama import LlamaAttention, LlamaModel, LlamaPreTrainedModel, LlamaRMSNorm
@auto_docstring(checkpoint="EuroBERT/EuroBERT-210m")
@strict
class EuroBertConfig(LlamaConfig):
    r"""
    mask_token_id (`int`, *optional*, defaults to 128002):
        Mask token id.
    classifier_pooling (`str`, *optional*, defaults to `"late"`):
        The pooling strategy to use for the classifier. Can be one of ['bos', 'mean', 'late']. `"bos"` feeds the
        first position's final hidden state to the classification head; `"mean"` feeds the attention-mask-weighted
        average of the final hidden states; `"late"` applies the classification head to every position and then
        averages the resulting logits with the attention mask.

    ```python
    >>> from transformers import EuroBertModel, EuroBertConfig

    >>> # Initializing a EuroBert eurobert-base style configuration
    >>> configuration = EuroBertConfig()

    >>> # Initializing a model from the eurobert-base style configuration
    >>> model = EuroBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "eurobert"
    vocab_size: int = 128256
    hidden_size: int = 768
    intermediate_size: int = 3072
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    # None is replaced by num_attention_heads in __post_init__ below.
    num_key_value_heads: int | None = None
    hidden_act: str = "silu"
    max_position_embeddings: int = 8192
    initializer_range: float = 0.02
    # Also the default `eps` of EuroBertRMSNorm in this file.
    rms_norm_eps: float = 1e-05
    bos_token_id: int | None = 128000
    eos_token_id: int | list[int] | None = 128001
    # By default padding reuses the eos id (128001).
    pad_token_id: int | None = 128001
    mask_token_id: int = 128002
    pretraining_tp: int = 1
    tie_word_embeddings: bool = False
    rope_parameters: RopeParameters | dict | None = None
    attention_bias: bool = False
    attention_dropout: int | float = 0.0
    # Also used as the `bias` flag of EuroBertForMaskedLM's lm_head.
    mlp_bias: bool = False
    head_dim: int | None = None
    # Read by EuroBertForSequenceClassification: one of "bos", "mean", "late".
    classifier_pooling: str = "late"

    def __post_init__(self, **kwargs):
        """Fill EuroBERT-specific derived defaults, then run LlamaConfig's post-init with the remaining kwargs."""
        # Without an explicit value, use one key/value head per attention head.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        super().__post_init__(**kwargs)
class EuroBertRMSNorm(LlamaRMSNorm):
    """LlamaRMSNorm subclass whose only change is a default `eps` of 1e-5.

    1e-5 matches `EuroBertConfig.rms_norm_eps`'s default; whether it differs from LlamaRMSNorm's own default
    is not visible here — NOTE(review): confirm against the parent signature.
    """

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__(hidden_size, eps)
class EuroBertAttention(LlamaAttention):
    """LlamaAttention with `is_causal` switched off so EuroBERT attends bidirectionally."""

    def __init__(self, config: EuroBertConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Presumably read by the attention backend to skip causal masking — confirm in LlamaAttention's forward.
        self.is_causal = False
class EuroBertPreTrainedModel(LlamaPreTrainedModel):
    """Shared base class for EuroBERT models; inherits all pretrained-model plumbing from LlamaPreTrainedModel."""

    pass
  80. class EuroBertModel(LlamaModel):
  81. def forward(
  82. self,
  83. input_ids: torch.LongTensor = None,
  84. attention_mask: torch.Tensor | None = None,
  85. position_ids: torch.LongTensor | None = None,
  86. inputs_embeds: torch.FloatTensor | None = None,
  87. **kwargs: Unpack[TransformersKwargs],
  88. ) -> tuple | BaseModelOutput:
  89. if (input_ids is None) ^ (inputs_embeds is not None):
  90. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  91. if inputs_embeds is None:
  92. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  93. if position_ids is None:
  94. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
  95. bidirectional_mask = create_bidirectional_mask(
  96. config=self.config,
  97. inputs_embeds=inputs_embeds,
  98. attention_mask=attention_mask,
  99. )
  100. hidden_states = inputs_embeds
  101. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  102. for encoder_layer in self.layers[: self.config.num_hidden_layers]:
  103. hidden_states = encoder_layer(
  104. hidden_states,
  105. attention_mask=bidirectional_mask,
  106. position_embeddings=position_embeddings,
  107. position_ids=position_ids,
  108. **kwargs,
  109. )
  110. hidden_states = self.norm(hidden_states)
  111. return BaseModelOutput(
  112. last_hidden_state=hidden_states,
  113. )
  114. @auto_docstring
  115. class EuroBertForMaskedLM(EuroBertPreTrainedModel):
  116. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  117. _tp_plan = {"lm_head": "colwise_gather_output"}
  118. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  119. def __init__(self, config: EuroBertConfig):
  120. super().__init__(config)
  121. self.model = EuroBertModel(config)
  122. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, config.mlp_bias)
  123. # Initialize weights and apply final processing
  124. self.post_init()
  125. @can_return_tuple
  126. @auto_docstring
  127. def forward(
  128. self,
  129. input_ids: torch.LongTensor | None = None,
  130. attention_mask: torch.Tensor | None = None,
  131. position_ids: torch.LongTensor | None = None,
  132. inputs_embeds: torch.FloatTensor | None = None,
  133. labels: torch.LongTensor | None = None,
  134. **kwargs: Unpack[TransformersKwargs],
  135. ) -> tuple[torch.Tensor] | MaskedLMOutput:
  136. r"""
  137. Example:
  138. ```python
  139. >>> from transformers import AutoTokenizer, EuroBertForMaskedLM
  140. >>> model = EuroBertForMaskedLM.from_pretrained("EuroBERT/EuroBERT-210m")
  141. >>> tokenizer = AutoTokenizer.from_pretrained("EuroBERT/EuroBERT-210m")
  142. >>> text = "The capital of France is <|mask|>."
  143. >>> inputs = tokenizer(text, return_tensors="pt")
  144. >>> outputs = model(**inputs)
  145. >>> # To get predictions for the mask:
  146. >>> masked_index = inputs["input_ids"][0].tolist().index(tokenizer.mask_token_id)
  147. >>> predicted_token_id = outputs.logits[0, masked_index].argmax(axis=-1)
  148. >>> predicted_token = tokenizer.decode(predicted_token_id)
  149. >>> print("Predicted token:", predicted_token)
  150. Predicted token: Paris
  151. ```"""
  152. outputs: BaseModelOutput = self.model(
  153. input_ids=input_ids,
  154. attention_mask=attention_mask,
  155. position_ids=position_ids,
  156. inputs_embeds=inputs_embeds,
  157. **kwargs,
  158. )
  159. logits = self.lm_head(outputs.last_hidden_state)
  160. loss = None
  161. if labels is not None:
  162. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  163. return MaskedLMOutput(
  164. loss=loss,
  165. logits=logits,
  166. hidden_states=outputs.hidden_states,
  167. attentions=outputs.attentions,
  168. )
  169. @auto_docstring
  170. class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
  171. def __init__(self, config: EuroBertConfig):
  172. super().__init__(config)
  173. self.num_labels = config.num_labels
  174. self.classifier_pooling = config.classifier_pooling
  175. self.model = EuroBertModel(config)
  176. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  177. self.activation = nn.GELU()
  178. self.classifier = nn.Linear(config.hidden_size, self.num_labels)
  179. self.post_init()
  180. @can_return_tuple
  181. @auto_docstring
  182. def forward(
  183. self,
  184. input_ids: torch.LongTensor | None = None,
  185. attention_mask: torch.Tensor | None = None,
  186. position_ids: torch.LongTensor | None = None,
  187. inputs_embeds: torch.FloatTensor | None = None,
  188. labels: torch.LongTensor | None = None,
  189. **kwargs: Unpack[TransformersKwargs],
  190. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  191. encoder_output = self.model(
  192. input_ids,
  193. attention_mask=attention_mask,
  194. position_ids=position_ids,
  195. inputs_embeds=inputs_embeds,
  196. **kwargs,
  197. )
  198. last_hidden_state = encoder_output[0]
  199. if self.classifier_pooling in ["bos", "mean"]:
  200. if self.classifier_pooling == "bos":
  201. pooled_output = last_hidden_state[:, 0]
  202. elif self.classifier_pooling == "mean":
  203. if attention_mask is None:
  204. pooled_output = last_hidden_state.mean(dim=1)
  205. else:
  206. attention_mask = attention_mask.to(last_hidden_state.device)
  207. pooled_output = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1)
  208. pooled_output /= attention_mask.sum(dim=1, keepdim=True)
  209. pooled_output = self.dense(pooled_output)
  210. pooled_output = self.activation(pooled_output)
  211. logits = self.classifier(pooled_output)
  212. elif self.classifier_pooling == "late":
  213. x = self.dense(last_hidden_state)
  214. x = self.activation(x)
  215. logits = self.classifier(x)
  216. if attention_mask is None:
  217. logits = logits.mean(dim=1)
  218. else:
  219. attention_mask = attention_mask.to(logits.device)
  220. logits = (logits * attention_mask.unsqueeze(-1)).sum(dim=1)
  221. logits /= attention_mask.sum(dim=1, keepdim=True)
  222. loss = None
  223. if labels is not None:
  224. labels = labels.to(logits.device)
  225. if self.config.problem_type is None:
  226. if self.num_labels == 1:
  227. self.config.problem_type = "regression"
  228. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  229. self.config.problem_type = "single_label_classification"
  230. else:
  231. self.config.problem_type = "multi_label_classification"
  232. if self.config.problem_type == "regression":
  233. loss_fct = MSELoss()
  234. if self.num_labels == 1:
  235. loss = loss_fct(logits.squeeze(), labels.squeeze())
  236. else:
  237. loss = loss_fct(logits, labels)
  238. elif self.config.problem_type == "single_label_classification":
  239. loss_fct = CrossEntropyLoss()
  240. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  241. elif self.config.problem_type == "multi_label_classification":
  242. loss_fct = BCEWithLogitsLoss()
  243. loss = loss_fct(logits, labels)
  244. return SequenceClassifierOutput(
  245. loss=loss,
  246. logits=logits,
  247. hidden_states=encoder_output.hidden_states,
  248. attentions=encoder_output.attentions,
  249. )
  250. @auto_docstring
  251. class EuroBertForTokenClassification(EuroBertPreTrainedModel):
  252. def __init__(self, config: EuroBertConfig):
  253. super().__init__(config)
  254. self.num_labels = config.num_labels
  255. self.model = EuroBertModel(config)
  256. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  257. self.post_init()
  258. def get_input_embeddings(self):
  259. return self.model.embed_tokens
  260. def set_input_embeddings(self, value):
  261. self.model.embed_tokens = value
  262. @can_return_tuple
  263. @auto_docstring
  264. def forward(
  265. self,
  266. input_ids: torch.LongTensor | None = None,
  267. attention_mask: torch.Tensor | None = None,
  268. position_ids: torch.LongTensor | None = None,
  269. inputs_embeds: torch.FloatTensor | None = None,
  270. labels: torch.LongTensor | None = None,
  271. **kwargs: Unpack[TransformersKwargs],
  272. ) -> tuple | TokenClassifierOutput:
  273. r"""
  274. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  275. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  276. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  277. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  278. """
  279. outputs = self.model(
  280. input_ids,
  281. attention_mask=attention_mask,
  282. position_ids=position_ids,
  283. inputs_embeds=inputs_embeds,
  284. **kwargs,
  285. )
  286. sequence_output = outputs[0]
  287. logits = self.classifier(sequence_output)
  288. loss = None
  289. if labels is not None:
  290. loss_fct = CrossEntropyLoss()
  291. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  292. return TokenClassifierOutput(
  293. loss=loss,
  294. logits=logits,
  295. hidden_states=outputs.hidden_states,
  296. attentions=outputs.attentions,
  297. )
# Public names exported by this module.
__all__ = [
    "EuroBertConfig",
    "EuroBertPreTrainedModel",
    "EuroBertModel",
    "EuroBertForMaskedLM",
    "EuroBertForSequenceClassification",
    "EuroBertForTokenClassification",
]