modular_modernbert.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921
  1. # Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. from typing import Literal, Optional
  17. import torch
  18. from huggingface_hub.dataclasses import strict
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...configuration_utils import PreTrainedConfig
  24. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  25. from ...masking_utils import create_bidirectional_mask, create_bidirectional_sliding_window_mask
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import (
  28. BaseModelOutput,
  29. MaskedLMOutput,
  30. MultipleChoiceModelOutput,
  31. QuestionAnsweringModelOutput,
  32. SequenceClassifierOutput,
  33. TokenClassifierOutput,
  34. )
  35. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
  36. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  37. from ...processing_utils import Unpack
  38. from ...utils import TransformersKwargs, auto_docstring, logging
  39. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  40. from ...utils.output_capturing import capture_outputs
  41. from ..align.modeling_align import eager_attention_forward
  42. from ..gemma3.modeling_gemma3 import Gemma3RotaryEmbedding, rotate_half
# Module-level logger obtained from the library's logging utilities, named after this module.
logger = logging.get_logger(__name__)
  44. @auto_docstring(checkpoint="answerdotai/ModernBERT-base")
  45. @strict
  46. class ModernBertConfig(PreTrainedConfig):
  47. r"""
  48. initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
  49. The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
  50. norm_eps (`float`, *optional*, defaults to 1e-05):
  51. The epsilon used by the rms normalization layers.
  52. norm_bias (`bool`, *optional*, defaults to `False`):
  53. Whether to use bias in the normalization layers.
  54. local_attention (`int`, *optional*, defaults to 128):
  55. The window size for local attention.
  56. mlp_dropout (`float`, *optional*, defaults to 0.0):
  57. The dropout ratio for the MLP layers.
  58. decoder_bias (`bool`, *optional*, defaults to `True`):
  59. Whether to use bias in the decoder layers.
  60. classifier_pooling (`str`, *optional*, defaults to `"cls"`):
  61. The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the
  62. CLS token doesn't attend to all tokens on long sequences.
  63. classifier_bias (`bool`, *optional*, defaults to `False`):
  64. Whether to use bias in the classifier.
  65. classifier_activation (`str`, *optional*, defaults to `"gelu"`):
  66. The activation function for the classifier.
  67. deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
  68. Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
  69. sparse_prediction (`bool`, *optional*, defaults to `False`):
  70. Whether to use sparse prediction for the masked language model instead of returning the full dense logits.
  71. sparse_pred_ignore_index (`int`, *optional*, defaults to -100):
  72. The index to ignore for the sparse prediction.
  73. Examples:
  74. ```python
  75. >>> from transformers import ModernBertModel, ModernBertConfig
  76. >>> # Initializing a ModernBert style configuration
  77. >>> configuration = ModernBertConfig()
  78. >>> # Initializing a model from the modernbert-base style configuration
  79. >>> model = ModernBertModel(configuration)
  80. >>> # Accessing the model configuration
  81. >>> configuration = model.config
  82. ```"""
  83. model_type = "modernbert"
  84. keys_to_ignore_at_inference = ["past_key_values"]
  85. default_theta = {"global": 160_000.0, "local": 10_000.0}
  86. vocab_size: int = 50368
  87. hidden_size: int = 768
  88. intermediate_size: int = 1152
  89. num_hidden_layers: int = 22
  90. num_attention_heads: int = 12
  91. hidden_activation: str = "gelu"
  92. max_position_embeddings: int = 8192
  93. initializer_range: float = 0.02
  94. initializer_cutoff_factor: float = 2.0
  95. norm_eps: float = 1e-5
  96. norm_bias: bool = False
  97. pad_token_id: int | None = 50283
  98. eos_token_id: int | list[int] | None = 50282
  99. bos_token_id: int | None = 50281
  100. cls_token_id: int | None = 50281
  101. sep_token_id: int | None = 50282
  102. attention_bias: bool = False
  103. attention_dropout: float | int = 0.0
  104. layer_types: list[str] | None = None
  105. rope_parameters: dict[Literal["full_attention", "sliding_attention"], dict] | None = None
  106. local_attention: int = 128
  107. embedding_dropout: float | int = 0.0
  108. mlp_bias: bool = False
  109. mlp_dropout: float | int = 0.0
  110. decoder_bias: bool = True
  111. classifier_pooling: Literal["cls", "mean"] = "cls"
  112. classifier_dropout: float | int = 0.0
  113. classifier_bias: bool = False
  114. classifier_activation: str = "gelu"
  115. deterministic_flash_attn: bool = False
  116. sparse_prediction: bool = False
  117. sparse_pred_ignore_index: int = -100
  118. tie_word_embeddings: bool = True
  119. def __post_init__(self, **kwargs):
  120. # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
  121. global_attn_every_n_layers = kwargs.get("global_attn_every_n_layers", 3)
  122. if self.layer_types is None:
  123. self.layer_types = [
  124. "sliding_attention" if bool(i % global_attn_every_n_layers) else "full_attention"
  125. for i in range(self.num_hidden_layers)
  126. ]
  127. super().__post_init__(**kwargs)
  128. def convert_rope_params_to_dict(self, **kwargs):
  129. rope_scaling = kwargs.pop("rope_scaling", None)
  130. # Try to set `rope_scaling` if available, otherwise use `rope_parameters`. If we find `rope_parameters`
  131. # as arg in the inputs, we can safely assume that it is in the new format. New naming used -> new format
  132. default_rope_params = {
  133. "sliding_attention": {"rope_type": "default"},
  134. "full_attention": {"rope_type": "default"},
  135. }
  136. self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else default_rope_params
  137. if rope_scaling is not None:
  138. self.rope_parameters["full_attention"].update(rope_scaling)
  139. self.rope_parameters["sliding_attention"].update(rope_scaling)
  140. # Set default values if not present
  141. if self.rope_parameters.get("full_attention") is None:
  142. self.rope_parameters["full_attention"] = {"rope_type": "default"}
  143. self.rope_parameters["full_attention"].setdefault(
  144. "rope_theta", kwargs.pop("global_rope_theta", self.default_theta["global"])
  145. )
  146. if self.rope_parameters.get("sliding_attention") is None:
  147. self.rope_parameters["sliding_attention"] = {"rope_type": "default"}
  148. self.rope_parameters["sliding_attention"].setdefault(
  149. "rope_theta", kwargs.pop("local_rope_theta", self.default_theta["local"])
  150. )
  151. # Standardize and validate the correctness of rotary position embeddings parameters
  152. self.standardize_rope_params()
  153. return kwargs
  154. def to_dict(self):
  155. output = super().to_dict()
  156. output.pop("reference_compile", None)
  157. return output
  158. @property
  159. def sliding_window(self):
  160. """Half-window size: `local_attention` is the total window, so we divide by 2."""
  161. return self.local_attention // 2
  162. @sliding_window.setter
  163. def sliding_window(self, value):
  164. """Set sliding_window by updating local_attention to 2 * value."""
  165. self.local_attention = value * 2
  166. class ModernBertEmbeddings(nn.Module):
  167. """
  168. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  169. """
  170. def __init__(self, config: ModernBertConfig):
  171. super().__init__()
  172. self.config = config
  173. self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  174. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  175. self.drop = nn.Dropout(config.embedding_dropout)
  176. def forward(
  177. self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.Tensor | None = None
  178. ) -> torch.Tensor:
  179. if inputs_embeds is not None:
  180. hidden_states = self.drop(self.norm(inputs_embeds))
  181. else:
  182. hidden_states = self.drop(self.norm(self.tok_embeddings(input_ids)))
  183. return hidden_states
  184. class ModernBertMLP(nn.Module):
  185. """Applies the GLU at the end of each ModernBERT layer.
  186. Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
  187. and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
  188. """
  189. def __init__(self, config: ModernBertConfig):
  190. super().__init__()
  191. self.config = config
  192. self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
  193. self.act = ACT2FN[config.hidden_activation]
  194. self.drop = nn.Dropout(config.mlp_dropout)
  195. self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
  196. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  197. input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
  198. return self.Wo(self.drop(self.act(input) * gate))
class ModernBertRotaryEmbedding(Gemma3RotaryEmbedding):
    """Rotary position embedding for ModernBERT, reusing `Gemma3RotaryEmbedding`.

    Both members below only delegate to the Gemma3 parent with their arguments unchanged; they re-annotate
    `config` as `ModernBertConfig` (presumably so the modular converter emits ModernBERT-typed code).
    """

    def __init__(self, config: ModernBertConfig, device=None):
        super().__init__(config, device)

    @staticmethod
    def compute_default_rope_parameters(
        config: ModernBertConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
        layer_type: str | None = None,
    ) -> tuple["torch.Tensor", float]:
        # Returns a (tensor, float) pair; `_init_weights` uses the tensor as the layer type's inverse frequencies.
        # NOTE(review): zero-argument `super()` inside a @staticmethod binds to the first argument (`config`),
        # which is not an instance of this class, so executed as plain Python this line would raise TypeError.
        # Presumably it relies on the modular converter (this is `modular_modernbert.py`) inlining the parent's
        # body into the generated modeling file — confirm before calling this from the modular file directly.
        return super().compute_default_rope_parameters(config, device, seq_len, layer_type)
  210. @use_kernel_func_from_hub("rotary_pos_emb")
  211. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  212. """Applies Rotary Position Embedding to the query and key tensors.
  213. Args:
  214. q (`torch.Tensor`): The query tensor.
  215. k (`torch.Tensor`): The key tensor.
  216. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  217. sin (`torch.Tensor`): The sine part of the rotary embedding.
  218. unsqueeze_dim (`int`, *optional*, defaults to 1):
  219. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  220. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  221. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  222. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  223. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  224. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  225. Returns:
  226. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  227. """
  228. original_dtype = q.dtype
  229. cos = cos.unsqueeze(unsqueeze_dim)
  230. sin = sin.unsqueeze(unsqueeze_dim)
  231. q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
  232. k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
  233. return q_embed.to(original_dtype), k_embed.to(original_dtype)
  234. @use_kernelized_func(apply_rotary_pos_emb)
  235. class ModernBertAttention(nn.Module):
  236. """Performs multi-headed self attention on a batch of unpadded sequences.
  237. If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
  238. If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
  239. which requires padding and unpadding inputs, adding some overhead.
  240. See `forward` method for additional details.
  241. """
  242. def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
  243. super().__init__()
  244. self.config = config
  245. self.layer_idx = layer_idx
  246. if config.hidden_size % config.num_attention_heads != 0:
  247. raise ValueError(
  248. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
  249. )
  250. self.attention_dropout = config.attention_dropout
  251. self.deterministic_flash_attn = config.deterministic_flash_attn
  252. self.head_dim = config.hidden_size // config.num_attention_heads
  253. self.Wqkv = nn.Linear(
  254. config.hidden_size, 3 * self.head_dim * config.num_attention_heads, bias=config.attention_bias
  255. )
  256. if config.layer_types[layer_idx] == "sliding_attention":
  257. # config.sliding_window = local_attention // 2 (half-window size, e.g. 64 for local_attention=128)
  258. # +1 is needed because flash attention sets inclusive boundaries (see modeling_flash_attention_utils.py)
  259. self.sliding_window = config.sliding_window + 1
  260. else:
  261. self.sliding_window = None
  262. self.is_causal = False
  263. self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
  264. self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
  265. def forward(
  266. self,
  267. hidden_states: torch.Tensor,
  268. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  269. attention_mask: torch.Tensor | None = None,
  270. **kwargs: Unpack[TransformersKwargs],
  271. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  272. input_shape = hidden_states.shape[:-1]
  273. qkv = self.Wqkv(hidden_states)
  274. qkv = qkv.view(*input_shape, 3, -1, self.head_dim)
  275. query_states, key_states, value_states = qkv.unbind(dim=-3)
  276. query_states = query_states.transpose(1, 2)
  277. key_states = key_states.transpose(1, 2)
  278. value_states = value_states.transpose(1, 2)
  279. cos, sin = position_embeddings
  280. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)
  281. attention_interface = eager_attention_forward
  282. if self.config._attn_implementation != "eager":
  283. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  284. attn_output, attn_weights = attention_interface(
  285. self,
  286. query_states,
  287. key_states,
  288. value_states,
  289. attention_mask,
  290. dropout=self.attention_dropout if self.training else 0.0,
  291. scaling=self.head_dim**-0.5,
  292. sliding_window=self.sliding_window,
  293. deterministic=self.deterministic_flash_attn,
  294. **kwargs,
  295. )
  296. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  297. attn_output = self.out_drop(self.Wo(attn_output))
  298. return attn_output, attn_weights
  299. class ModernBertEncoderLayer(GradientCheckpointingLayer):
  300. def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
  301. super().__init__()
  302. self.config = config
  303. self.layer_idx = layer_idx
  304. if layer_idx == 0:
  305. self.attn_norm = nn.Identity()
  306. else:
  307. self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  308. self.attn = ModernBertAttention(config=config, layer_idx=layer_idx)
  309. self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  310. self.mlp = ModernBertMLP(config)
  311. self.attention_type = config.layer_types[layer_idx]
  312. def forward(
  313. self,
  314. hidden_states: torch.Tensor,
  315. attention_mask: torch.Tensor | None = None,
  316. position_embeddings: torch.Tensor | None = None,
  317. **kwargs: Unpack[TransformersKwargs],
  318. ) -> torch.Tensor:
  319. attn_output, _ = self.attn(
  320. self.attn_norm(hidden_states),
  321. position_embeddings=position_embeddings,
  322. attention_mask=attention_mask,
  323. **kwargs,
  324. )
  325. hidden_states = hidden_states + attn_output
  326. hidden_states = hidden_states + self.mlp(self.mlp_norm(hidden_states))
  327. return hidden_states
  328. @auto_docstring
  329. class ModernBertPreTrainedModel(PreTrainedModel):
  330. config: ModernBertConfig
  331. base_model_prefix = "model"
  332. supports_gradient_checkpointing = True
  333. _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
  334. _supports_flash_attn = True
  335. _supports_sdpa = True
  336. _supports_flex_attn = True
  337. _supports_attention_backend = True
  338. _can_record_outputs = {
  339. "hidden_states": ModernBertEncoderLayer,
  340. "attentions": ModernBertAttention,
  341. }
  342. @torch.no_grad()
  343. def _init_weights(self, module: nn.Module):
  344. cutoff_factor = self.config.initializer_cutoff_factor
  345. if cutoff_factor is None:
  346. cutoff_factor = 3
  347. def init_weight(module: nn.Module, std: float):
  348. init.trunc_normal_(
  349. module.weight,
  350. mean=0.0,
  351. std=std,
  352. a=-cutoff_factor * std,
  353. b=cutoff_factor * std,
  354. )
  355. if isinstance(module, nn.Linear):
  356. if module.bias is not None:
  357. init.zeros_(module.bias)
  358. stds = {
  359. "in": self.config.initializer_range,
  360. "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
  361. "embedding": self.config.initializer_range,
  362. "final_out": self.config.hidden_size**-0.5,
  363. }
  364. if isinstance(module, ModernBertEmbeddings):
  365. init_weight(module.tok_embeddings, stds["embedding"])
  366. elif isinstance(module, ModernBertMLP):
  367. init_weight(module.Wi, stds["in"])
  368. init_weight(module.Wo, stds["out"])
  369. elif isinstance(module, ModernBertAttention):
  370. init_weight(module.Wqkv, stds["in"])
  371. init_weight(module.Wo, stds["out"])
  372. elif isinstance(module, ModernBertPredictionHead):
  373. init_weight(module.dense, stds["out"])
  374. elif isinstance(module, ModernBertForMaskedLM):
  375. init_weight(module.decoder, stds["out"])
  376. elif isinstance(
  377. module,
  378. (
  379. ModernBertForSequenceClassification,
  380. ModernBertForMultipleChoice,
  381. ModernBertForTokenClassification,
  382. ModernBertForQuestionAnswering,
  383. ),
  384. ):
  385. init_weight(module.classifier, stds["final_out"])
  386. elif isinstance(module, nn.LayerNorm):
  387. init.ones_(module.weight)
  388. if module.bias is not None:
  389. init.zeros_(module.bias)
  390. elif isinstance(module, ModernBertRotaryEmbedding):
  391. for layer_type in module.layer_types:
  392. rope_init_fn = module.compute_default_rope_parameters
  393. if module.rope_type[layer_type] != "default":
  394. rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
  395. curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
  396. init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
  397. init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
  398. @auto_docstring
  399. class ModernBertModel(ModernBertPreTrainedModel):
  400. def __init__(self, config: ModernBertConfig):
  401. super().__init__(config)
  402. self.config = config
  403. self.embeddings = ModernBertEmbeddings(config)
  404. self.layers = nn.ModuleList(
  405. [ModernBertEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  406. )
  407. self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  408. self.rotary_emb = ModernBertRotaryEmbedding(config=config)
  409. self.gradient_checkpointing = False
  410. self.post_init()
  411. def get_input_embeddings(self):
  412. return self.embeddings.tok_embeddings
  413. def set_input_embeddings(self, value):
  414. self.embeddings.tok_embeddings = value
  415. @merge_with_config_defaults
  416. @capture_outputs
  417. @auto_docstring
  418. def forward(
  419. self,
  420. input_ids: torch.LongTensor | None = None,
  421. attention_mask: torch.Tensor | None = None,
  422. position_ids: torch.LongTensor | None = None,
  423. inputs_embeds: torch.Tensor | None = None,
  424. **kwargs: Unpack[TransformersKwargs],
  425. ) -> BaseModelOutput:
  426. if (input_ids is None) ^ (inputs_embeds is not None):
  427. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  428. seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
  429. device = input_ids.device if input_ids is not None else inputs_embeds.device
  430. if position_ids is None:
  431. position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
  432. hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
  433. if not isinstance(attention_mask_mapping := attention_mask, dict):
  434. mask_kwargs = {
  435. "config": self.config,
  436. "inputs_embeds": hidden_states,
  437. "attention_mask": attention_mask,
  438. }
  439. attention_mask_mapping = {
  440. "full_attention": create_bidirectional_mask(**mask_kwargs),
  441. "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
  442. }
  443. position_embeddings = {}
  444. for layer_type in self.config.layer_types:
  445. position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
  446. for encoder_layer in self.layers:
  447. hidden_states = encoder_layer(
  448. hidden_states,
  449. attention_mask=attention_mask_mapping[encoder_layer.attention_type],
  450. position_embeddings=position_embeddings[encoder_layer.attention_type],
  451. **kwargs,
  452. )
  453. hidden_states = self.final_norm(hidden_states)
  454. return BaseModelOutput(last_hidden_state=hidden_states)
  455. class ModernBertPredictionHead(nn.Module):
  456. def __init__(self, config: ModernBertConfig):
  457. super().__init__()
  458. self.config = config
  459. self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
  460. self.act = ACT2FN[config.classifier_activation]
  461. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  462. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  463. return self.norm(self.act(self.dense(hidden_states)))
  464. @auto_docstring(
  465. custom_intro="""
  466. The ModernBert Model with a decoder head on top that is used for masked language modeling.
  467. """
  468. )
  469. class ModernBertForMaskedLM(ModernBertPreTrainedModel):
  470. _tied_weights_keys = {"decoder.weight": "model.embeddings.tok_embeddings.weight"}
  471. def __init__(self, config: ModernBertConfig):
  472. super().__init__(config)
  473. self.config = config
  474. self.model = ModernBertModel(config)
  475. self.head = ModernBertPredictionHead(config)
  476. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
  477. self.sparse_prediction = self.config.sparse_prediction
  478. self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
  479. # Initialize weights and apply final processing
  480. self.post_init()
  481. def get_output_embeddings(self):
  482. return self.decoder
  483. def set_output_embeddings(self, new_embeddings: nn.Linear):
  484. self.decoder = new_embeddings
  485. @can_return_tuple
  486. @auto_docstring
  487. def forward(
  488. self,
  489. input_ids: torch.LongTensor | None = None,
  490. attention_mask: torch.Tensor | None = None,
  491. position_ids: torch.Tensor | None = None,
  492. inputs_embeds: torch.Tensor | None = None,
  493. labels: torch.Tensor | None = None,
  494. **kwargs: Unpack[TransformersKwargs],
  495. ) -> tuple[torch.Tensor] | MaskedLMOutput:
  496. outputs = self.model(
  497. input_ids=input_ids,
  498. attention_mask=attention_mask,
  499. position_ids=position_ids,
  500. inputs_embeds=inputs_embeds,
  501. **kwargs,
  502. )
  503. last_hidden_state = outputs[0]
  504. if self.sparse_prediction and labels is not None:
  505. # flatten labels and output first
  506. labels = labels.view(-1)
  507. last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
  508. # then filter out the non-masked tokens
  509. mask_tokens = labels != self.sparse_pred_ignore_index
  510. last_hidden_state = last_hidden_state[mask_tokens]
  511. labels = labels[mask_tokens]
  512. logits = self.decoder(self.head(last_hidden_state))
  513. loss = None
  514. if labels is not None:
  515. loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
  516. return MaskedLMOutput(
  517. loss=loss,
  518. logits=logits,
  519. hidden_states=outputs.hidden_states,
  520. attentions=outputs.attentions,
  521. )
  522. @auto_docstring(
  523. custom_intro="""
  524. The ModernBert Model with a sequence classification head on top that performs pooling.
  525. """
  526. )
  527. class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
  528. def __init__(self, config: ModernBertConfig):
  529. super().__init__(config)
  530. self.num_labels = config.num_labels
  531. self.config = config
  532. self.model = ModernBertModel(config)
  533. self.head = ModernBertPredictionHead(config)
  534. self.drop = torch.nn.Dropout(config.classifier_dropout)
  535. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  536. # Initialize weights and apply final processing
  537. self.post_init()
  538. @can_return_tuple
  539. @auto_docstring
  540. def forward(
  541. self,
  542. input_ids: torch.LongTensor | None = None,
  543. attention_mask: torch.Tensor | None = None,
  544. position_ids: torch.Tensor | None = None,
  545. inputs_embeds: torch.Tensor | None = None,
  546. labels: torch.Tensor | None = None,
  547. **kwargs: Unpack[TransformersKwargs],
  548. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  549. r"""
  550. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  551. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  552. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  553. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  554. """
  555. outputs = self.model(
  556. input_ids=input_ids,
  557. attention_mask=attention_mask,
  558. position_ids=position_ids,
  559. inputs_embeds=inputs_embeds,
  560. **kwargs,
  561. )
  562. last_hidden_state = outputs[0]
  563. if self.config.classifier_pooling == "cls":
  564. last_hidden_state = last_hidden_state[:, 0]
  565. elif self.config.classifier_pooling == "mean":
  566. if attention_mask is None:
  567. attention_mask = torch.ones(
  568. last_hidden_state.shape[:2], device=last_hidden_state.device, dtype=torch.bool
  569. )
  570. last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
  571. dim=1, keepdim=True
  572. )
  573. pooled_output = self.head(last_hidden_state)
  574. pooled_output = self.drop(pooled_output)
  575. logits = self.classifier(pooled_output)
  576. loss = None
  577. if labels is not None:
  578. if self.config.problem_type is None:
  579. if self.num_labels == 1:
  580. self.config.problem_type = "regression"
  581. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  582. self.config.problem_type = "single_label_classification"
  583. else:
  584. self.config.problem_type = "multi_label_classification"
  585. if self.config.problem_type == "regression":
  586. loss_fct = MSELoss()
  587. if self.num_labels == 1:
  588. loss = loss_fct(logits.squeeze(), labels.squeeze())
  589. else:
  590. loss = loss_fct(logits, labels)
  591. elif self.config.problem_type == "single_label_classification":
  592. loss_fct = CrossEntropyLoss()
  593. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  594. elif self.config.problem_type == "multi_label_classification":
  595. loss_fct = BCEWithLogitsLoss()
  596. loss = loss_fct(logits, labels)
  597. return SequenceClassifierOutput(
  598. loss=loss,
  599. logits=logits,
  600. hidden_states=outputs.hidden_states,
  601. attentions=outputs.attentions,
  602. )
  603. @auto_docstring(
  604. custom_intro="""
  605. The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
  606. """
  607. )
  608. class ModernBertForTokenClassification(ModernBertPreTrainedModel):
  609. def __init__(self, config: ModernBertConfig):
  610. super().__init__(config)
  611. self.num_labels = config.num_labels
  612. self.model = ModernBertModel(config)
  613. self.head = ModernBertPredictionHead(config)
  614. self.drop = torch.nn.Dropout(config.classifier_dropout)
  615. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  616. # Initialize weights and apply final processing
  617. self.post_init()
  618. @can_return_tuple
  619. @auto_docstring
  620. def forward(
  621. self,
  622. input_ids: torch.LongTensor | None = None,
  623. attention_mask: torch.Tensor | None = None,
  624. position_ids: torch.Tensor | None = None,
  625. inputs_embeds: torch.Tensor | None = None,
  626. labels: torch.Tensor | None = None,
  627. **kwargs: Unpack[TransformersKwargs],
  628. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  629. r"""
  630. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  631. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  632. """
  633. outputs = self.model(
  634. input_ids=input_ids,
  635. attention_mask=attention_mask,
  636. position_ids=position_ids,
  637. inputs_embeds=inputs_embeds,
  638. **kwargs,
  639. )
  640. last_hidden_state = outputs[0]
  641. last_hidden_state = self.head(last_hidden_state)
  642. last_hidden_state = self.drop(last_hidden_state)
  643. logits = self.classifier(last_hidden_state)
  644. loss = None
  645. if labels is not None:
  646. loss_fct = CrossEntropyLoss()
  647. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  648. return TokenClassifierOutput(
  649. loss=loss,
  650. logits=logits,
  651. hidden_states=outputs.hidden_states,
  652. attentions=outputs.attentions,
  653. )
  654. @auto_docstring
  655. class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
  656. def __init__(self, config: ModernBertConfig):
  657. super().__init__(config)
  658. self.num_labels = config.num_labels
  659. self.model = ModernBertModel(config)
  660. self.head = ModernBertPredictionHead(config)
  661. self.drop = torch.nn.Dropout(config.classifier_dropout)
  662. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  663. self.post_init()
  664. @can_return_tuple
  665. @auto_docstring
  666. def forward(
  667. self,
  668. input_ids: torch.Tensor | None = None,
  669. attention_mask: torch.Tensor | None = None,
  670. position_ids: torch.Tensor | None = None,
  671. start_positions: torch.Tensor | None = None,
  672. end_positions: torch.Tensor | None = None,
  673. **kwargs: Unpack[TransformersKwargs],
  674. ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
  675. outputs = self.model(
  676. input_ids,
  677. attention_mask=attention_mask,
  678. position_ids=position_ids,
  679. **kwargs,
  680. )
  681. last_hidden_state = outputs[0]
  682. last_hidden_state = self.head(last_hidden_state)
  683. last_hidden_state = self.drop(last_hidden_state)
  684. logits = self.classifier(last_hidden_state)
  685. start_logits, end_logits = logits.split(1, dim=-1)
  686. start_logits = start_logits.squeeze(-1).contiguous()
  687. end_logits = end_logits.squeeze(-1).contiguous()
  688. loss = None
  689. if start_positions is not None and end_positions is not None:
  690. loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
  691. return QuestionAnsweringModelOutput(
  692. loss=loss,
  693. start_logits=start_logits,
  694. end_logits=end_logits,
  695. hidden_states=outputs.hidden_states,
  696. attentions=outputs.attentions,
  697. )
  698. @auto_docstring(
  699. custom_intro="""
  700. The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
  701. """
  702. )
  703. class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
  704. def __init__(self, config: ModernBertConfig):
  705. super().__init__(config)
  706. self.config = config
  707. self.model = ModernBertModel(config)
  708. self.head = ModernBertPredictionHead(config)
  709. self.drop = torch.nn.Dropout(config.classifier_dropout)
  710. self.classifier = nn.Linear(config.hidden_size, 1)
  711. # Initialize weights and apply final processing
  712. self.post_init()
  713. @can_return_tuple
  714. @auto_docstring
  715. def forward(
  716. self,
  717. input_ids: torch.LongTensor | None = None,
  718. attention_mask: torch.Tensor | None = None,
  719. position_ids: torch.Tensor | None = None,
  720. inputs_embeds: torch.Tensor | None = None,
  721. labels: torch.Tensor | None = None,
  722. **kwargs: Unpack[TransformersKwargs],
  723. ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
  724. r"""
  725. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  726. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  727. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
  728. """
  729. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  730. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  731. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  732. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  733. inputs_embeds = (
  734. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  735. if inputs_embeds is not None
  736. else None
  737. )
  738. outputs = self.model(
  739. input_ids=input_ids,
  740. attention_mask=attention_mask,
  741. position_ids=position_ids,
  742. inputs_embeds=inputs_embeds,
  743. **kwargs,
  744. )
  745. last_hidden_state = outputs[0] # shape (num_choices, seq_len, hidden_size)
  746. # If classifier_pooling is "cls", isolate the <cls> token
  747. if self.config.classifier_pooling == "cls":
  748. indices_0 = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
  749. # for left or right padding, <cls> is the first non-pad token
  750. if attention_mask is not None:
  751. cls_mask = attention_mask.argmax(dim=-1).to(last_hidden_state.device)
  752. # if no pad, <cls> is the first token
  753. else:
  754. cls_mask = torch.tensor(0, dtype=torch.long, device=last_hidden_state.device)
  755. # extract the <cls> token for the logits
  756. last_hidden_state = last_hidden_state[indices_0, cls_mask]
  757. # If classifier_pooling is "mean", pool the hidden states by averaging over the sequence length
  758. elif self.config.classifier_pooling == "mean":
  759. num_non_pad_tokens = attention_mask.sum(dim=1, keepdim=True)
  760. last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / num_non_pad_tokens
  761. pooled_output = self.head(last_hidden_state)
  762. pooled_output = self.drop(pooled_output)
  763. logits = self.classifier(pooled_output)
  764. reshaped_logits = logits.view(-1, num_choices)
  765. loss = None
  766. if labels is not None:
  767. loss_fct = nn.CrossEntropyLoss()
  768. loss = loss_fct(reshaped_logits, labels)
  769. return MultipleChoiceModelOutput(
  770. loss=loss,
  771. logits=reshaped_logits,
  772. hidden_states=outputs.hidden_states,
  773. attentions=outputs.attentions,
  774. )
  775. __all__ = [
  776. "ModernBertConfig",
  777. "ModernBertModel",
  778. "ModernBertPreTrainedModel",
  779. "ModernBertForMaskedLM",
  780. "ModernBertForSequenceClassification",
  781. "ModernBertForTokenClassification",
  782. "ModernBertForQuestionAnswering",
  783. "ModernBertForMultipleChoice",
  784. ]