modeling_modernbert.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_modernbert.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from collections.abc import Callable
  23. from typing import Optional
  24. import torch
  25. from torch import nn
  26. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  27. from ... import initialization as init
  28. from ...activations import ACT2FN
  29. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  30. from ...masking_utils import create_bidirectional_mask, create_bidirectional_sliding_window_mask
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import (
  33. BaseModelOutput,
  34. MaskedLMOutput,
  35. MultipleChoiceModelOutput,
  36. QuestionAnsweringModelOutput,
  37. SequenceClassifierOutput,
  38. TokenClassifierOutput,
  39. )
  40. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  41. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  42. from ...processing_utils import Unpack
  43. from ...utils import TransformersKwargs, auto_docstring
  44. from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults
  45. from ...utils.output_capturing import capture_outputs
  46. from .configuration_modernbert import ModernBertConfig
  47. class ModernBertEmbeddings(nn.Module):
  48. """
  49. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  50. """
  51. def __init__(self, config: ModernBertConfig):
  52. super().__init__()
  53. self.config = config
  54. self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  55. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  56. self.drop = nn.Dropout(config.embedding_dropout)
  57. def forward(
  58. self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.Tensor | None = None
  59. ) -> torch.Tensor:
  60. if inputs_embeds is not None:
  61. hidden_states = self.drop(self.norm(inputs_embeds))
  62. else:
  63. hidden_states = self.drop(self.norm(self.tok_embeddings(input_ids)))
  64. return hidden_states
  65. class ModernBertMLP(nn.Module):
  66. """Applies the GLU at the end of each ModernBERT layer.
  67. Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
  68. and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
  69. """
  70. def __init__(self, config: ModernBertConfig):
  71. super().__init__()
  72. self.config = config
  73. self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
  74. self.act = ACT2FN[config.hidden_activation]
  75. self.drop = nn.Dropout(config.mlp_dropout)
  76. self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
  77. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  78. input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
  79. return self.Wo(self.drop(self.act(input) * gate))
  80. class ModernBertRotaryEmbedding(nn.Module):
  81. inv_freq: torch.Tensor # fix linting for `register_buffer`
  82. def __init__(self, config: ModernBertConfig, device=None):
  83. super().__init__()
  84. self.max_seq_len_cached = config.max_position_embeddings
  85. self.original_max_seq_len = config.max_position_embeddings
  86. self.config = config
  87. self.layer_types = list(set(config.layer_types))
  88. self.rope_type = {}
  89. for layer_type in self.layer_types:
  90. rope_params = self.config.rope_parameters[layer_type]
  91. if rope_params is None:
  92. continue
  93. self.rope_type[layer_type] = rope_params["rope_type"]
  94. rope_init_fn: Callable = self.compute_default_rope_parameters
  95. if self.rope_type[layer_type] != "default":
  96. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
  97. curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
  98. self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
  99. self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
  100. setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
  101. @staticmethod
  102. def compute_default_rope_parameters(
  103. config: ModernBertConfig | None = None,
  104. device: Optional["torch.device"] = None,
  105. seq_len: int | None = None,
  106. layer_type: str | None = None,
  107. ) -> tuple["torch.Tensor", float]:
  108. """
  109. Computes the inverse frequencies according to the original RoPE implementation
  110. Args:
  111. config ([`~transformers.PreTrainedConfig`]):
  112. The model configuration.
  113. device (`torch.device`):
  114. The device to use for initialization of the inverse frequencies.
  115. seq_len (`int`, *optional*):
  116. The current sequence length. Unused for this type of RoPE.
  117. layer_type (`str`, *optional*):
  118. The current layer type if the model has different RoPE parameters per type.
  119. Should not be used unless `config.layer_types is not None`
  120. Returns:
  121. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  122. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  123. """
  124. # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
  125. base = config.rope_parameters[layer_type]["rope_theta"]
  126. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  127. attention_factor = 1.0 # Unused in this type of RoPE
  128. # Compute the inverse frequencies
  129. inv_freq = 1.0 / (
  130. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  131. )
  132. return inv_freq, attention_factor
  133. @torch.no_grad()
  134. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  135. def forward(self, x, position_ids, layer_type=None):
  136. inv_freq = getattr(self, f"{layer_type}_inv_freq")
  137. attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
  138. inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  139. position_ids_expanded = position_ids[:, None, :].float()
  140. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  141. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  142. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  143. emb = torch.cat((freqs, freqs), dim=-1)
  144. cos = emb.cos() * attention_scaling
  145. sin = emb.sin() * attention_scaling
  146. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  147. def eager_attention_forward(
  148. module: nn.Module,
  149. query: torch.Tensor,
  150. key: torch.Tensor,
  151. value: torch.Tensor,
  152. attention_mask: torch.Tensor | None,
  153. scaling: float,
  154. dropout: float = 0.0,
  155. **kwargs,
  156. ):
  157. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  158. if attention_mask is not None:
  159. attn_weights = attn_weights + attention_mask
  160. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  161. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  162. attn_output = torch.matmul(attn_weights, value)
  163. attn_output = attn_output.transpose(1, 2).contiguous()
  164. return attn_output, attn_weights
  165. def rotate_half(x):
  166. """Rotates half the hidden dims of the input."""
  167. x1 = x[..., : x.shape[-1] // 2]
  168. x2 = x[..., x.shape[-1] // 2 :]
  169. return torch.cat((-x2, x1), dim=-1)
  170. @use_kernel_func_from_hub("rotary_pos_emb")
  171. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  172. """Applies Rotary Position Embedding to the query and key tensors.
  173. Args:
  174. q (`torch.Tensor`): The query tensor.
  175. k (`torch.Tensor`): The key tensor.
  176. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  177. sin (`torch.Tensor`): The sine part of the rotary embedding.
  178. unsqueeze_dim (`int`, *optional*, defaults to 1):
  179. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  180. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  181. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  182. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  183. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  184. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  185. Returns:
  186. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  187. """
  188. original_dtype = q.dtype
  189. cos = cos.unsqueeze(unsqueeze_dim)
  190. sin = sin.unsqueeze(unsqueeze_dim)
  191. q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
  192. k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
  193. return q_embed.to(original_dtype), k_embed.to(original_dtype)
  194. @use_kernelized_func(apply_rotary_pos_emb)
  195. class ModernBertAttention(nn.Module):
  196. """Performs multi-headed self attention on a batch of unpadded sequences.
  197. If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
  198. If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
  199. which requires padding and unpadding inputs, adding some overhead.
  200. See `forward` method for additional details.
  201. """
  202. def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
  203. super().__init__()
  204. self.config = config
  205. self.layer_idx = layer_idx
  206. if config.hidden_size % config.num_attention_heads != 0:
  207. raise ValueError(
  208. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
  209. )
  210. self.attention_dropout = config.attention_dropout
  211. self.deterministic_flash_attn = config.deterministic_flash_attn
  212. self.head_dim = config.hidden_size // config.num_attention_heads
  213. self.Wqkv = nn.Linear(
  214. config.hidden_size, 3 * self.head_dim * config.num_attention_heads, bias=config.attention_bias
  215. )
  216. if config.layer_types[layer_idx] == "sliding_attention":
  217. # config.sliding_window = local_attention // 2 (half-window size, e.g. 64 for local_attention=128)
  218. # +1 is needed because flash attention sets inclusive boundaries (see modeling_flash_attention_utils.py)
  219. self.sliding_window = config.sliding_window + 1
  220. else:
  221. self.sliding_window = None
  222. self.is_causal = False
  223. self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
  224. self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
  225. def forward(
  226. self,
  227. hidden_states: torch.Tensor,
  228. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  229. attention_mask: torch.Tensor | None = None,
  230. **kwargs: Unpack[TransformersKwargs],
  231. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  232. input_shape = hidden_states.shape[:-1]
  233. qkv = self.Wqkv(hidden_states)
  234. qkv = qkv.view(*input_shape, 3, -1, self.head_dim)
  235. query_states, key_states, value_states = qkv.unbind(dim=-3)
  236. query_states = query_states.transpose(1, 2)
  237. key_states = key_states.transpose(1, 2)
  238. value_states = value_states.transpose(1, 2)
  239. cos, sin = position_embeddings
  240. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)
  241. attention_interface = eager_attention_forward
  242. if self.config._attn_implementation != "eager":
  243. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  244. attn_output, attn_weights = attention_interface(
  245. self,
  246. query_states,
  247. key_states,
  248. value_states,
  249. attention_mask,
  250. dropout=self.attention_dropout if self.training else 0.0,
  251. scaling=self.head_dim**-0.5,
  252. sliding_window=self.sliding_window,
  253. deterministic=self.deterministic_flash_attn,
  254. **kwargs,
  255. )
  256. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  257. attn_output = self.out_drop(self.Wo(attn_output))
  258. return attn_output, attn_weights
  259. class ModernBertEncoderLayer(GradientCheckpointingLayer):
  260. def __init__(self, config: ModernBertConfig, layer_idx: int | None = None):
  261. super().__init__()
  262. self.config = config
  263. self.layer_idx = layer_idx
  264. if layer_idx == 0:
  265. self.attn_norm = nn.Identity()
  266. else:
  267. self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  268. self.attn = ModernBertAttention(config=config, layer_idx=layer_idx)
  269. self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  270. self.mlp = ModernBertMLP(config)
  271. self.attention_type = config.layer_types[layer_idx]
  272. def forward(
  273. self,
  274. hidden_states: torch.Tensor,
  275. attention_mask: torch.Tensor | None = None,
  276. position_embeddings: torch.Tensor | None = None,
  277. **kwargs: Unpack[TransformersKwargs],
  278. ) -> torch.Tensor:
  279. attn_output, _ = self.attn(
  280. self.attn_norm(hidden_states),
  281. position_embeddings=position_embeddings,
  282. attention_mask=attention_mask,
  283. **kwargs,
  284. )
  285. hidden_states = hidden_states + attn_output
  286. hidden_states = hidden_states + self.mlp(self.mlp_norm(hidden_states))
  287. return hidden_states
  288. @auto_docstring
  289. class ModernBertPreTrainedModel(PreTrainedModel):
  290. config: ModernBertConfig
  291. base_model_prefix = "model"
  292. supports_gradient_checkpointing = True
  293. _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
  294. _supports_flash_attn = True
  295. _supports_sdpa = True
  296. _supports_flex_attn = True
  297. _supports_attention_backend = True
  298. _can_record_outputs = {
  299. "hidden_states": ModernBertEncoderLayer,
  300. "attentions": ModernBertAttention,
  301. }
  302. @torch.no_grad()
  303. def _init_weights(self, module: nn.Module):
  304. cutoff_factor = self.config.initializer_cutoff_factor
  305. if cutoff_factor is None:
  306. cutoff_factor = 3
  307. def init_weight(module: nn.Module, std: float):
  308. init.trunc_normal_(
  309. module.weight,
  310. mean=0.0,
  311. std=std,
  312. a=-cutoff_factor * std,
  313. b=cutoff_factor * std,
  314. )
  315. if isinstance(module, nn.Linear):
  316. if module.bias is not None:
  317. init.zeros_(module.bias)
  318. stds = {
  319. "in": self.config.initializer_range,
  320. "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
  321. "embedding": self.config.initializer_range,
  322. "final_out": self.config.hidden_size**-0.5,
  323. }
  324. if isinstance(module, ModernBertEmbeddings):
  325. init_weight(module.tok_embeddings, stds["embedding"])
  326. elif isinstance(module, ModernBertMLP):
  327. init_weight(module.Wi, stds["in"])
  328. init_weight(module.Wo, stds["out"])
  329. elif isinstance(module, ModernBertAttention):
  330. init_weight(module.Wqkv, stds["in"])
  331. init_weight(module.Wo, stds["out"])
  332. elif isinstance(module, ModernBertPredictionHead):
  333. init_weight(module.dense, stds["out"])
  334. elif isinstance(module, ModernBertForMaskedLM):
  335. init_weight(module.decoder, stds["out"])
  336. elif isinstance(
  337. module,
  338. (
  339. ModernBertForSequenceClassification,
  340. ModernBertForMultipleChoice,
  341. ModernBertForTokenClassification,
  342. ModernBertForQuestionAnswering,
  343. ),
  344. ):
  345. init_weight(module.classifier, stds["final_out"])
  346. elif isinstance(module, nn.LayerNorm):
  347. init.ones_(module.weight)
  348. if module.bias is not None:
  349. init.zeros_(module.bias)
  350. elif isinstance(module, ModernBertRotaryEmbedding):
  351. for layer_type in module.layer_types:
  352. rope_init_fn = module.compute_default_rope_parameters
  353. if module.rope_type[layer_type] != "default":
  354. rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
  355. curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
  356. init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
  357. init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
  358. @auto_docstring
  359. class ModernBertModel(ModernBertPreTrainedModel):
  360. def __init__(self, config: ModernBertConfig):
  361. super().__init__(config)
  362. self.config = config
  363. self.embeddings = ModernBertEmbeddings(config)
  364. self.layers = nn.ModuleList(
  365. [ModernBertEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  366. )
  367. self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  368. self.rotary_emb = ModernBertRotaryEmbedding(config=config)
  369. self.gradient_checkpointing = False
  370. self.post_init()
  371. def get_input_embeddings(self):
  372. return self.embeddings.tok_embeddings
  373. def set_input_embeddings(self, value):
  374. self.embeddings.tok_embeddings = value
  375. @merge_with_config_defaults
  376. @capture_outputs
  377. @auto_docstring
  378. def forward(
  379. self,
  380. input_ids: torch.LongTensor | None = None,
  381. attention_mask: torch.Tensor | None = None,
  382. position_ids: torch.LongTensor | None = None,
  383. inputs_embeds: torch.Tensor | None = None,
  384. **kwargs: Unpack[TransformersKwargs],
  385. ) -> BaseModelOutput:
  386. if (input_ids is None) ^ (inputs_embeds is not None):
  387. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  388. seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
  389. device = input_ids.device if input_ids is not None else inputs_embeds.device
  390. if position_ids is None:
  391. position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
  392. hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
  393. if not isinstance(attention_mask_mapping := attention_mask, dict):
  394. mask_kwargs = {
  395. "config": self.config,
  396. "inputs_embeds": hidden_states,
  397. "attention_mask": attention_mask,
  398. }
  399. attention_mask_mapping = {
  400. "full_attention": create_bidirectional_mask(**mask_kwargs),
  401. "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
  402. }
  403. position_embeddings = {}
  404. for layer_type in self.config.layer_types:
  405. position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
  406. for encoder_layer in self.layers:
  407. hidden_states = encoder_layer(
  408. hidden_states,
  409. attention_mask=attention_mask_mapping[encoder_layer.attention_type],
  410. position_embeddings=position_embeddings[encoder_layer.attention_type],
  411. **kwargs,
  412. )
  413. hidden_states = self.final_norm(hidden_states)
  414. return BaseModelOutput(last_hidden_state=hidden_states)
  415. class ModernBertPredictionHead(nn.Module):
  416. def __init__(self, config: ModernBertConfig):
  417. super().__init__()
  418. self.config = config
  419. self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
  420. self.act = ACT2FN[config.classifier_activation]
  421. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  422. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  423. return self.norm(self.act(self.dense(hidden_states)))
  424. @auto_docstring(
  425. custom_intro="""
  426. The ModernBert Model with a decoder head on top that is used for masked language modeling.
  427. """
  428. )
  429. class ModernBertForMaskedLM(ModernBertPreTrainedModel):
  430. _tied_weights_keys = {"decoder.weight": "model.embeddings.tok_embeddings.weight"}
  431. def __init__(self, config: ModernBertConfig):
  432. super().__init__(config)
  433. self.config = config
  434. self.model = ModernBertModel(config)
  435. self.head = ModernBertPredictionHead(config)
  436. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
  437. self.sparse_prediction = self.config.sparse_prediction
  438. self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
  439. # Initialize weights and apply final processing
  440. self.post_init()
  441. def get_output_embeddings(self):
  442. return self.decoder
  443. def set_output_embeddings(self, new_embeddings: nn.Linear):
  444. self.decoder = new_embeddings
  445. @can_return_tuple
  446. @auto_docstring
  447. def forward(
  448. self,
  449. input_ids: torch.LongTensor | None = None,
  450. attention_mask: torch.Tensor | None = None,
  451. position_ids: torch.Tensor | None = None,
  452. inputs_embeds: torch.Tensor | None = None,
  453. labels: torch.Tensor | None = None,
  454. **kwargs: Unpack[TransformersKwargs],
  455. ) -> tuple[torch.Tensor] | MaskedLMOutput:
  456. outputs = self.model(
  457. input_ids=input_ids,
  458. attention_mask=attention_mask,
  459. position_ids=position_ids,
  460. inputs_embeds=inputs_embeds,
  461. **kwargs,
  462. )
  463. last_hidden_state = outputs[0]
  464. if self.sparse_prediction and labels is not None:
  465. # flatten labels and output first
  466. labels = labels.view(-1)
  467. last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
  468. # then filter out the non-masked tokens
  469. mask_tokens = labels != self.sparse_pred_ignore_index
  470. last_hidden_state = last_hidden_state[mask_tokens]
  471. labels = labels[mask_tokens]
  472. logits = self.decoder(self.head(last_hidden_state))
  473. loss = None
  474. if labels is not None:
  475. loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
  476. return MaskedLMOutput(
  477. loss=loss,
  478. logits=logits,
  479. hidden_states=outputs.hidden_states,
  480. attentions=outputs.attentions,
  481. )
  482. @auto_docstring(
  483. custom_intro="""
  484. The ModernBert Model with a sequence classification head on top that performs pooling.
  485. """
  486. )
  487. class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
  488. def __init__(self, config: ModernBertConfig):
  489. super().__init__(config)
  490. self.num_labels = config.num_labels
  491. self.config = config
  492. self.model = ModernBertModel(config)
  493. self.head = ModernBertPredictionHead(config)
  494. self.drop = torch.nn.Dropout(config.classifier_dropout)
  495. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  496. # Initialize weights and apply final processing
  497. self.post_init()
  498. @can_return_tuple
  499. @auto_docstring
  500. def forward(
  501. self,
  502. input_ids: torch.LongTensor | None = None,
  503. attention_mask: torch.Tensor | None = None,
  504. position_ids: torch.Tensor | None = None,
  505. inputs_embeds: torch.Tensor | None = None,
  506. labels: torch.Tensor | None = None,
  507. **kwargs: Unpack[TransformersKwargs],
  508. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  509. r"""
  510. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  511. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  512. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  513. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  514. """
  515. outputs = self.model(
  516. input_ids=input_ids,
  517. attention_mask=attention_mask,
  518. position_ids=position_ids,
  519. inputs_embeds=inputs_embeds,
  520. **kwargs,
  521. )
  522. last_hidden_state = outputs[0]
  523. if self.config.classifier_pooling == "cls":
  524. last_hidden_state = last_hidden_state[:, 0]
  525. elif self.config.classifier_pooling == "mean":
  526. if attention_mask is None:
  527. attention_mask = torch.ones(
  528. last_hidden_state.shape[:2], device=last_hidden_state.device, dtype=torch.bool
  529. )
  530. last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
  531. dim=1, keepdim=True
  532. )
  533. pooled_output = self.head(last_hidden_state)
  534. pooled_output = self.drop(pooled_output)
  535. logits = self.classifier(pooled_output)
  536. loss = None
  537. if labels is not None:
  538. if self.config.problem_type is None:
  539. if self.num_labels == 1:
  540. self.config.problem_type = "regression"
  541. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  542. self.config.problem_type = "single_label_classification"
  543. else:
  544. self.config.problem_type = "multi_label_classification"
  545. if self.config.problem_type == "regression":
  546. loss_fct = MSELoss()
  547. if self.num_labels == 1:
  548. loss = loss_fct(logits.squeeze(), labels.squeeze())
  549. else:
  550. loss = loss_fct(logits, labels)
  551. elif self.config.problem_type == "single_label_classification":
  552. loss_fct = CrossEntropyLoss()
  553. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  554. elif self.config.problem_type == "multi_label_classification":
  555. loss_fct = BCEWithLogitsLoss()
  556. loss = loss_fct(logits, labels)
  557. return SequenceClassifierOutput(
  558. loss=loss,
  559. logits=logits,
  560. hidden_states=outputs.hidden_states,
  561. attentions=outputs.attentions,
  562. )
  563. @auto_docstring(
  564. custom_intro="""
  565. The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
  566. """
  567. )
  568. class ModernBertForTokenClassification(ModernBertPreTrainedModel):
  569. def __init__(self, config: ModernBertConfig):
  570. super().__init__(config)
  571. self.num_labels = config.num_labels
  572. self.model = ModernBertModel(config)
  573. self.head = ModernBertPredictionHead(config)
  574. self.drop = torch.nn.Dropout(config.classifier_dropout)
  575. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  576. # Initialize weights and apply final processing
  577. self.post_init()
  578. @can_return_tuple
  579. @auto_docstring
  580. def forward(
  581. self,
  582. input_ids: torch.LongTensor | None = None,
  583. attention_mask: torch.Tensor | None = None,
  584. position_ids: torch.Tensor | None = None,
  585. inputs_embeds: torch.Tensor | None = None,
  586. labels: torch.Tensor | None = None,
  587. **kwargs: Unpack[TransformersKwargs],
  588. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  589. r"""
  590. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  591. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  592. """
  593. outputs = self.model(
  594. input_ids=input_ids,
  595. attention_mask=attention_mask,
  596. position_ids=position_ids,
  597. inputs_embeds=inputs_embeds,
  598. **kwargs,
  599. )
  600. last_hidden_state = outputs[0]
  601. last_hidden_state = self.head(last_hidden_state)
  602. last_hidden_state = self.drop(last_hidden_state)
  603. logits = self.classifier(last_hidden_state)
  604. loss = None
  605. if labels is not None:
  606. loss_fct = CrossEntropyLoss()
  607. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  608. return TokenClassifierOutput(
  609. loss=loss,
  610. logits=logits,
  611. hidden_states=outputs.hidden_states,
  612. attentions=outputs.attentions,
  613. )
  614. @auto_docstring
  615. class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
  616. def __init__(self, config: ModernBertConfig):
  617. super().__init__(config)
  618. self.num_labels = config.num_labels
  619. self.model = ModernBertModel(config)
  620. self.head = ModernBertPredictionHead(config)
  621. self.drop = torch.nn.Dropout(config.classifier_dropout)
  622. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  623. self.post_init()
  624. @can_return_tuple
  625. @auto_docstring
  626. def forward(
  627. self,
  628. input_ids: torch.Tensor | None = None,
  629. attention_mask: torch.Tensor | None = None,
  630. position_ids: torch.Tensor | None = None,
  631. start_positions: torch.Tensor | None = None,
  632. end_positions: torch.Tensor | None = None,
  633. **kwargs: Unpack[TransformersKwargs],
  634. ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
  635. outputs = self.model(
  636. input_ids,
  637. attention_mask=attention_mask,
  638. position_ids=position_ids,
  639. **kwargs,
  640. )
  641. last_hidden_state = outputs[0]
  642. last_hidden_state = self.head(last_hidden_state)
  643. last_hidden_state = self.drop(last_hidden_state)
  644. logits = self.classifier(last_hidden_state)
  645. start_logits, end_logits = logits.split(1, dim=-1)
  646. start_logits = start_logits.squeeze(-1).contiguous()
  647. end_logits = end_logits.squeeze(-1).contiguous()
  648. loss = None
  649. if start_positions is not None and end_positions is not None:
  650. loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
  651. return QuestionAnsweringModelOutput(
  652. loss=loss,
  653. start_logits=start_logits,
  654. end_logits=end_logits,
  655. hidden_states=outputs.hidden_states,
  656. attentions=outputs.attentions,
  657. )
  658. @auto_docstring(
  659. custom_intro="""
  660. The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
  661. """
  662. )
  663. class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
  664. def __init__(self, config: ModernBertConfig):
  665. super().__init__(config)
  666. self.config = config
  667. self.model = ModernBertModel(config)
  668. self.head = ModernBertPredictionHead(config)
  669. self.drop = torch.nn.Dropout(config.classifier_dropout)
  670. self.classifier = nn.Linear(config.hidden_size, 1)
  671. # Initialize weights and apply final processing
  672. self.post_init()
  673. @can_return_tuple
  674. @auto_docstring
  675. def forward(
  676. self,
  677. input_ids: torch.LongTensor | None = None,
  678. attention_mask: torch.Tensor | None = None,
  679. position_ids: torch.Tensor | None = None,
  680. inputs_embeds: torch.Tensor | None = None,
  681. labels: torch.Tensor | None = None,
  682. **kwargs: Unpack[TransformersKwargs],
  683. ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
  684. r"""
  685. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  686. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  687. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
  688. """
  689. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  690. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  691. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  692. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  693. inputs_embeds = (
  694. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  695. if inputs_embeds is not None
  696. else None
  697. )
  698. outputs = self.model(
  699. input_ids=input_ids,
  700. attention_mask=attention_mask,
  701. position_ids=position_ids,
  702. inputs_embeds=inputs_embeds,
  703. **kwargs,
  704. )
  705. last_hidden_state = outputs[0] # shape (num_choices, seq_len, hidden_size)
  706. # If classifier_pooling is "cls", isolate the <cls> token
  707. if self.config.classifier_pooling == "cls":
  708. indices_0 = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
  709. # for left or right padding, <cls> is the first non-pad token
  710. if attention_mask is not None:
  711. cls_mask = attention_mask.argmax(dim=-1).to(last_hidden_state.device)
  712. # if no pad, <cls> is the first token
  713. else:
  714. cls_mask = torch.tensor(0, dtype=torch.long, device=last_hidden_state.device)
  715. # extract the <cls> token for the logits
  716. last_hidden_state = last_hidden_state[indices_0, cls_mask]
  717. # If classifier_pooling is "mean", pool the hidden states by averaging over the sequence length
  718. elif self.config.classifier_pooling == "mean":
  719. num_non_pad_tokens = attention_mask.sum(dim=1, keepdim=True)
  720. last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / num_non_pad_tokens
  721. pooled_output = self.head(last_hidden_state)
  722. pooled_output = self.drop(pooled_output)
  723. logits = self.classifier(pooled_output)
  724. reshaped_logits = logits.view(-1, num_choices)
  725. loss = None
  726. if labels is not None:
  727. loss_fct = nn.CrossEntropyLoss()
  728. loss = loss_fct(reshaped_logits, labels)
  729. return MultipleChoiceModelOutput(
  730. loss=loss,
  731. logits=reshaped_logits,
  732. hidden_states=outputs.hidden_states,
  733. attentions=outputs.attentions,
  734. )
# Public API of this module: the names exported by `from <module> import *`.
# Kept as a plain literal list so static tooling can read it without executing code.
__all__ = [
    "ModernBertModel",
    "ModernBertPreTrainedModel",
    "ModernBertForMaskedLM",
    "ModernBertForSequenceClassification",
    "ModernBertForTokenClassification",
    "ModernBertForQuestionAnswering",
    "ModernBertForMultipleChoice",
]