# Copyright 2025 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import torch.nn as nn
  15. from huggingface_hub.dataclasses import strict
  16. from ...utils import auto_docstring, can_return_tuple
  17. from ..llama.configuration_llama import LlamaConfig
  18. from ..llama.modeling_llama import (
  19. LlamaDecoderLayer,
  20. LlamaForCausalLM,
  21. LlamaModel,
  22. LlamaPreTrainedModel,
  23. )
  24. from ..nemotron.modeling_nemotron import NemotronMLP
  25. @auto_docstring(checkpoint="inceptionai/Jais-2-8B-Chat")
  26. @strict
  27. class Jais2Config(LlamaConfig):
  28. base_model_tp_plan = {
  29. "layers.*.self_attn.q_proj": "colwise",
  30. "layers.*.self_attn.k_proj": "colwise",
  31. "layers.*.self_attn.v_proj": "colwise",
  32. "layers.*.self_attn.o_proj": "rowwise",
  33. "layers.*.mlp.up_proj": "colwise",
  34. "layers.*.mlp.down_proj": "rowwise",
  35. }
  36. vocab_size: int = 150272
  37. hidden_size: int = 3328
  38. intermediate_size: int = 26624
  39. num_attention_heads: int = 26
  40. hidden_act: str = "relu2"
  41. max_position_embeddings: int = 8192
  42. layer_norm_eps: float = 1e-5
  43. bos_token_id: int | None = 0
  44. eos_token_id: int | list[int] | None = 150024
  45. attention_bias: bool = True
  46. mlp_bias: bool = True
  47. rms_norm_eps = AttributeError()
  48. pretraining_tp = AttributeError()
  49. class Jais2MLP(NemotronMLP):
  50. pass
  51. class Jais2DecoderLayer(LlamaDecoderLayer):
  52. def __init__(self, config: Jais2Config, layer_idx: int):
  53. super().__init__(config, layer_idx)
  54. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  55. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  56. class Jais2PreTrainedModel(LlamaPreTrainedModel):
  57. pass
  58. class Jais2Model(LlamaModel):
  59. def __init__(self, config: Jais2Config):
  60. super().__init__(config)
  61. self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  62. class Jais2ForCausalLM(LlamaForCausalLM):
  63. @can_return_tuple
  64. @auto_docstring
  65. def forward(self, **super_kwargs):
  66. r"""
  67. Example:
  68. ```python
  69. >>> from transformers import AutoTokenizer, Jais2ForCausalLM
  70. >>> model = Jais2ForCausalLM.from_pretrained("inceptionai/Jais-2-8B-Chat")
  71. >>> tokenizer = AutoTokenizer.from_pretrained("inceptionai/Jais-2-8B-Chat")
  72. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  73. >>> inputs = tokenizer(prompt, return_tensors="pt")
  74. >>> # Generate
  75. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  76. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  77. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  78. ```"""
  79. return super().forward(**super_kwargs)
  80. __all__ = [
  81. "Jais2Config",
  82. "Jais2Model",
  83. "Jais2ForCausalLM",
  84. "Jais2PreTrainedModel",
  85. ]