configuration_fsmt.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """FSMT configuration"""
  15. from huggingface_hub.dataclasses import strict
  16. from ...configuration_utils import PreTrainedConfig
  17. from ...utils import auto_docstring
  18. @auto_docstring(checkpoint="facebook/wmt19-en-ru")
  19. @strict
  20. class FSMTConfig(PreTrainedConfig):
  21. r"""
  22. langs (`list[str]`):
  23. A list with source language and target_language (e.g., ['en', 'ru']).
  24. src_vocab_size (`int`):
  25. Vocabulary size of the encoder. Defines the number of different tokens that can be represented by the
  26. `inputs_ids` passed to the forward method in the encoder.
  27. tgt_vocab_size (`int`):
  28. Vocabulary size of the decoder. Defines the number of different tokens that can be represented by the
  29. `inputs_ids` passed to the forward method in the decoder.
  30. max_length (`int`, *optional*, defaults to 200):
  31. Maximum length to generate.
  32. num_beams (`int`, *optional*, defaults to 5):
  33. Number of beams for beam search that will be used by default in the `generate` method of the model. 1 means
  34. no beam search.
  35. length_penalty (`float`, *optional*, defaults to 1):
  36. Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
  37. the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
  38. likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
  39. `length_penalty` < 0.0 encourages shorter sequences.
  40. early_stopping (`bool`, *optional*, defaults to `False`):
  41. Flag that will be used by default in the `generate` method of the model. Whether to stop the beam search
  42. when at least `num_beams` sentences are finished per batch or not.
  43. Examples:
  44. ```python
  45. >>> from transformers import FSMTConfig, FSMTModel
  46. >>> # Initializing a FSMT facebook/wmt19-en-ru style configuration
  47. >>> config = FSMTConfig()
  48. >>> # Initializing a model (with random weights) from the configuration
  49. >>> model = FSMTModel(config)
  50. >>> # Accessing the model configuration
  51. >>> configuration = model.config
  52. ```"""
  53. model_type = "fsmt"
  54. attribute_map = {
  55. "num_attention_heads": "encoder_attention_heads",
  56. "hidden_size": "d_model",
  57. "vocab_size": "tgt_vocab_size",
  58. "num_hidden_layers": "encoder_layers",
  59. }
  60. langs: list[str] | tuple[str, ...] = ("en", "de")
  61. src_vocab_size: int = 42024
  62. tgt_vocab_size: int = 42024
  63. activation_function: str = "relu"
  64. d_model: int = 1024
  65. max_length: int = 200
  66. max_position_embeddings: int = 1024
  67. encoder_ffn_dim: int = 4096
  68. encoder_layers: int = 12
  69. encoder_attention_heads: int = 16
  70. encoder_layerdrop: float | int = 0.0
  71. decoder_ffn_dim: int = 4096
  72. decoder_layers: int = 12
  73. decoder_attention_heads: int = 16
  74. decoder_layerdrop: float | int = 0.0
  75. attention_dropout: float | int = 0.0
  76. dropout: float | int = 0.1
  77. activation_dropout: float | int = 0.0
  78. init_std: float = 0.02
  79. decoder_start_token_id: int | None = 2
  80. is_encoder_decoder: bool = True
  81. scale_embedding: bool = True
  82. tie_word_embeddings: bool = False
  83. num_beams: int = 5
  84. length_penalty: float = 1.0
  85. early_stopping: bool = False
  86. use_cache: bool = True
  87. pad_token_id: int | None = 1
  88. bos_token_id: int | None = 0
  89. eos_token_id: int | list[int] | None = 2
  90. forced_eos_token_id: int | list[int] | None = 2
  91. def __post_init__(self, **kwargs):
  92. kwargs.pop("decoder", None) # delete unused kwargs
  93. super().__post_init__(**kwargs)
# Public names exported by `from <this module> import *`.
__all__ = ["FSMTConfig"]