configuration_utils.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352
  1. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Configuration base class and utilities."""
  16. import copy
  17. import json
  18. import math
  19. import os
  20. from collections.abc import Sequence
  21. from dataclasses import MISSING, dataclass, fields
  22. from functools import wraps
  23. from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar, Union
  24. from huggingface_hub import create_repo
  25. from huggingface_hub.dataclasses import strict
  26. from packaging import version
  27. from . import __version__
  28. from .dynamic_module_utils import custom_object_save
  29. from .generation.configuration_utils import GenerationConfig
  30. from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
  31. from .modeling_rope_utils import RotaryEmbeddingConfigMixin
  32. from .utils import (
  33. CONFIG_NAME,
  34. PushToHubMixin,
  35. cached_file,
  36. copy_func,
  37. extract_commit_hash,
  38. is_torch_available,
  39. logging,
  40. )
  41. from .utils.generic import is_timm_config_dict
  42. if TYPE_CHECKING:
  43. import torch
# Module-level logger for this file.
logger = logging.get_logger(__name__)
# type hinting: specifying the type of config class that inherits from PreTrainedConfig
SpecificPreTrainedConfigType = TypeVar("SpecificPreTrainedConfigType", bound="PreTrainedConfig")
# Tag key and the accepted tagged string values for non-finite floats (inf, -inf, NaN).
# NOTE(review): the helpers that encode/decode with these constants are not visible in this chunk; presumably
# they are used to round-trip non-finite floats through JSON — confirm against the (de)serialization code.
_FLOAT_TAG_KEY = "__float__"
_FLOAT_TAG_VALUES = {"Infinity": float("inf"), "-Infinity": float("-inf"), "NaN": float("nan")}
# Every value accepted as an entry of a config's `layer_types` list; checked by
# `PreTrainedConfig.validate_layer_type`, which also prints this tuple in its error message.
ALLOWED_LAYER_TYPES = (
    "full_attention",
    "sliding_attention",
    "chunked_attention",
    "linear_attention",  # used in minimax
    "conv",  # used in LFMv2
    "mamba",
    "attention",
    "sparse",
    "dense",
    "hybrid",  # for layers that have both mamba and attention in zamba and zamba2
    "moe",  # for nemotron_h, which uses either attention, mamba or moe
)
  62. # copied from huggingface_hub.dataclasses.strict when `accept_kwargs=True`
  63. def wrap_init_to_accept_kwargs(cls: dataclass):
  64. original_init = cls.__init__
  65. @wraps(original_init)
  66. def __init__(self, *args, **kwargs: Any) -> None:
  67. # Extract only the fields that are part of the dataclass
  68. dataclass_fields = {f.name for f in fields(cls)}
  69. standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields}
  70. # We need to call bare `__init__` without `__post_init__` but the `original_init` of
  71. # any dataclas contains a call to post-init at the end (without kwargs)
  72. if len(args) > 0:
  73. raise ValueError(
  74. f"{cls.__name__} accepts only keyword arguments, but found `{len(args)}` positional args."
  75. )
  76. for f in fields(cls): # type: ignore
  77. if f.name in standard_kwargs:
  78. setattr(self, f.name, standard_kwargs[f.name])
  79. elif f.default is not MISSING:
  80. setattr(self, f.name, f.default)
  81. elif f.default_factory is not MISSING:
  82. setattr(self, f.name, f.default_factory())
  83. else:
  84. raise TypeError(f"Missing required field - '{f.name}'")
  85. # Pass any additional kwargs to `__post_init__` and let the object
  86. # decide whether to set the attr or use for different purposes (e.g. BC checks)
  87. additional_kwargs = {}
  88. for name, value in kwargs.items():
  89. if name not in dataclass_fields:
  90. additional_kwargs[name] = value
  91. self.__post_init__(**additional_kwargs)
  92. cls.__init__ = __init__
  93. return cls
  94. @strict(accept_kwargs=True)
  95. @dataclass(repr=False)
  96. class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin):
  97. # no-format
  98. r"""
  99. Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
  100. methods for loading/downloading/saving configurations.
  101. <Tip>
  102. A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
  103. initialize a model does **not** load the model weights. It only affects the model's configuration.
  104. </Tip>
  105. Class attributes (overridden by derived classes):
  106. - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
  107. the correct object in [`~transformers.AutoConfig`].
  108. - **has_no_defaults_at_init** (`bool`) -- Whether the config class can be initialized without providing input arguments.
  109. Some configurations requires inputs to be defined at init and have no default values, usually these are composite configs,
  110. (but not necessarily) such as [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`]. They have to be initialized from
  111. two or more configs of type [`~transformers.PreTrainedConfig`].
  112. - **keys_to_ignore_at_inference** (`list[str]`) -- A list of keys to ignore by default when looking at dictionary
  113. outputs of the model during inference.
  114. - **attribute_map** (`dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
  115. naming of attributes.
  116. - **base_model_tp_plan** (`dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
  117. parallel plan applied to the sub-module when `model.tensor_parallel` is called.
  118. - **base_model_pp_plan** (`dict[str, tuple[list[str]]]`) -- A dict that maps child-modules of a base model to a
  119. pipeline parallel plan that enables users to place the child-module on the appropriate device.
  120. Common attributes (present in all subclasses):
  121. - **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
  122. embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
  123. - **hidden_size** (`int`) -- The hidden size of the model.
  124. - **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the
  125. model.
  126. - **num_hidden_layers** (`int`) -- The number of blocks in the model.
  127. <Tip warning={true}>
  128. Setting parameters for sequence generation in the model config is deprecated. For backward compatibility, loading
  129. some of them will still be possible, but attempting to overwrite them will throw an exception -- you should set
  130. them in a [~transformers.GenerationConfig]. Check the documentation of [~transformers.GenerationConfig] for more
  131. information about the individual parameters.
  132. </Tip>
  133. Arg:
  134. name_or_path (`str`, *optional*, defaults to `""`):
  135. Store the string that was passed to [`PreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path`
  136. if the configuration was created with such a method.
  137. output_hidden_states (`bool`, *optional*, defaults to `False`):
  138. Whether or not the model should return all hidden-states.
  139. output_attentions (`bool`, *optional*, defaults to `False`):
  140. Whether or not the model should returns all attentions.
  141. return_dict (`bool`, *optional*, defaults to `True`):
  142. Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple.
  143. is_encoder_decoder (`bool`, *optional*, defaults to `False`):
  144. Whether the model is used as an encoder/decoder or not.
  145. chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
  146. The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
  147. the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
  148. sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
  149. Forward Chunking work?](../glossary.html#feed-forward-chunking).
  150. > Parameters for fine-tuning tasks
  151. architectures (`list[str]`, *optional*):
  152. Model architectures that can be used with the model pretrained weights.
  153. id2label (`dict[int, str]`, *optional*):
  154. A map from index (for instance prediction index, or target index) to label.
  155. label2id (`dict[str, int]`, *optional*):
  156. A map from label to index for the model.
  157. num_labels (`int`, *optional*):
  158. Number of labels to use in the last layer added to the model, typically for a classification task.
  159. problem_type (`str`, *optional*):
  160. Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
  161. `"single_label_classification"` or `"multi_label_classification"`.
  162. > PyTorch specific parameters
  163. dtype (`str`, *optional*):
  164. The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
  165. (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
  166. model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
  167. `float16` weights.
  168. """
  169. # Class attributes that we don't want to save or have in `self.__dict__`
  170. # They are not supposed to be set/changed by users. Each field is set when
  171. # creating a model class
  172. base_config_key: ClassVar[str] = ""
  173. sub_configs: ClassVar[dict[str, type["PreTrainedConfig"]]] = {}
  174. has_no_defaults_at_init: ClassVar[bool] = False
  175. keys_to_ignore_at_inference: ClassVar[list[str]] = []
  176. attribute_map: ClassVar[dict[str, str]] = {}
  177. base_model_tp_plan: ClassVar[dict[str, Any] | None] = None
  178. base_model_pp_plan: ClassVar[dict[str, Sequence[list[str]]] | None] = None
  179. base_model_ep_plan: ClassVar[dict[str, Sequence[list[str]]] | None] = None
  180. _auto_class: ClassVar[str | None] = None
  181. # Attributes set internally when saving and used to infer model
  182. # class for `Auto` mapping
  183. model_type: ClassVar[str] = ""
  184. transformers_version: str | None = None
  185. architectures: list[str] | None = None
  186. # Common attributes for all models
  187. output_hidden_states: bool | None = False
  188. return_dict: bool | None = True
  189. dtype: Union[str, "torch.dtype"] | None = None
  190. chunk_size_feed_forward: int = 0
  191. is_encoder_decoder: bool = False
  192. # Fine-tuning task arguments
  193. id2label: dict[int, str] | dict[str, str] | None = None
  194. label2id: dict[str, int] | dict[str, str] | None = None
  195. problem_type: Literal["regression", "single_label_classification", "multi_label_classification"] | None = None
    def __post_init__(self, **kwargs):
        """
        Finish initialization once the dataclass fields have been assigned.

        Handles backward-compatible inputs (`torch_dtype`, `num_labels`, legacy RoPE keys), drops generation
        parameters, stores internal bookkeeping attributes, and finally sets every remaining keyword argument as
        an attribute of the config.

        Args:
            kwargs:
                Keyword arguments that are not dataclass fields of this config.

        Raises:
            AttributeError: if a remaining keyword argument cannot be set on the config (an error is logged first,
                then the exception is re-raised).
        """
        # BC for the `torch_dtype` argument instead of the simpler `dtype`
        # Do not warn, as it would otherwise always be triggered since most configs on the hub have `torch_dtype`
        if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
            # If both are provided, keep `dtype`
            self.dtype = self.dtype if self.dtype is not None else torch_dtype
        if self.dtype is not None and isinstance(self.dtype, str) and is_torch_available():
            # we will start using self.dtype in v5, but to be consistent with
            # from_pretrained's dtype arg convert it to an actual torch.dtype object
            # NOTE(review): `getattr` raises AttributeError if the string does not name an attribute of `torch`.
            import torch
            self.dtype = getattr(torch, self.dtype)
        # Keep the default value of `num_labels=2` in case users have saved a classifier with 2 labels
        # Our configs prev wouldn't save `id2label` for 2 labels because it is the default. In all other
        # cases we expect the config dict to have an `id2label` field if it's a clf model, or not otherwise
        # NOTE: `num_labels` is only read here (`.get`, not `.pop`), so it stays in `kwargs` and is assigned again
        # by the `setattr` loop at the end of this method (assuming `convert_rope_params_to_dict` passes unrelated
        # keys through — TODO confirm). When it disagrees with a provided `id2label`, that second assignment goes
        # through the `num_labels` setter, which replaces `id2label`/`label2id` with generic `LABEL_i` entries.
        if self.id2label is None:
            self.num_labels = kwargs.get("num_labels", 2)
        else:
            if kwargs.get("num_labels") is not None and len(self.id2label) != kwargs.get("num_labels"):
                logger.warning(
                    f"You passed `num_labels={kwargs.get('num_labels')}` which is incompatible to "
                    f"the `id2label` map of length `{len(self.id2label)}`."
                )
            # Keys are always strings in JSON so convert ids to int
            self.id2label = {int(key): value for key, value in self.id2label.items()}
        # BC for rotary embeddings. We will pop out legacy keys from kwargs and rename to new format
        if hasattr(self, "rope_parameters"):
            kwargs = self.convert_rope_params_to_dict(**kwargs)
        elif kwargs.get("rope_scaling") and kwargs.get("rope_theta"):
            logger.warning(
                f"{self.__class__.__name__} got `key=rope_scaling` in kwargs but hasn't set it as attribute. "
                "For RoPE standardization you need to set `self.rope_parameters` in model's config. "
            )
            kwargs = self.convert_rope_params_to_dict(**kwargs)
        # Parameters for sequence generation saved in the config are popped instead of loading them.
        for parameter_name in GenerationConfig._get_default_generation_params().keys():
            kwargs.pop(parameter_name, None)
        # Name or path to the pretrained checkpoint
        self._name_or_path = str(kwargs.pop("name_or_path", ""))
        self._commit_hash = kwargs.pop("_commit_hash", None)
        # Attention/Experts implementation to use, if relevant (it sets it recursively on sub-configs)
        # `_output_attentions` is assigned directly, bypassing the `output_attentions` property setter, so a
        # truthy value here neither switches the attention implementation to eager nor validates it.
        self._output_attentions: bool | None = kwargs.pop("output_attentions", False)
        self._attn_implementation: str | None = kwargs.pop("attn_implementation", None)
        self._experts_implementation: str | None = kwargs.pop("experts_implementation", None)
        # Additional attributes without default values
        # (`setattr` goes through `__setattr__`, so names listed in `attribute_map` land on their canonical attribute)
        for key, value in kwargs.items():
            # Check this to avoid deserializing problematic fields from hub configs - they should use the public field
            if key not in ("_attn_implementation_internal", "_experts_implementation_internal"):
                try:
                    setattr(self, key, value)
                except AttributeError as err:
                    logger.error(f"Can't set {key} with value {value} for {self}")
                    raise err
  248. def __init_subclass__(cls, *args, **kwargs):
  249. super().__init_subclass__(*args, **kwargs)
  250. cls_has_custom_init = "__init__" in cls.__dict__
  251. # kw_only=True ensures fields without defaults in subclasses can follow
  252. # parent fields that have defaults (Python dataclass ordering rule).
  253. # Config fields are always passed as keyword arguments, so this is safe.
  254. cls = dataclass(cls, repr=False, kw_only=True)
  255. if not cls_has_custom_init:
  256. # Wrap all subclasses to accept arbitrary kwargs for BC
  257. # only if the subclass has no custom `__init__`. Most
  258. # remote code has an init defined, but some model are not
  259. # See https://huggingface.co/hmellor/Ilama-3.2-1B/blob/main/configuration_ilama.py
  260. cls = wrap_init_to_accept_kwargs(cls)
  261. @property
  262. def name_or_path(self) -> str | None:
  263. return getattr(self, "_name_or_path", None)
  264. @name_or_path.setter
  265. def name_or_path(self, value):
  266. self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
  267. @property
  268. def num_labels(self) -> int:
  269. """
  270. `int`: The number of labels for classification models.
  271. """
  272. return len(self.id2label) if self.id2label is not None else None
  273. @num_labels.setter
  274. def num_labels(self, num_labels: int):
  275. # we do not store `num_labels` attribute in config, but instead
  276. # compute it based on the length of the `id2label` map
  277. if self.id2label is None or self.num_labels != num_labels:
  278. self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
  279. self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
  280. @property
  281. def output_attentions(self):
  282. """
  283. `bool`: Whether or not the model should returns all attentions.
  284. """
  285. return self._output_attentions
  286. @output_attentions.setter
  287. def output_attentions(self, value: bool):
  288. # If we set `output_attentions` explicitly before the attn implementation, dispatch eager
  289. if value and self._attn_implementation is None:
  290. self._attn_implementation = "eager"
  291. if value and self._attn_implementation != "eager":
  292. raise ValueError(
  293. "The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
  294. f"{self._attn_implementation}. Please set it to 'eager' instead."
  295. )
  296. self._output_attentions = value
  297. @property
  298. def _attn_implementation(self):
  299. return self._attn_implementation_internal
  300. @_attn_implementation.setter
  301. def _attn_implementation(self, value: str | dict | None):
  302. """We set it recursively on the sub-configs as well"""
  303. # Set if for current config
  304. current_attn = getattr(self, "_attn_implementation", None)
  305. attn_implementation = value if not isinstance(value, dict) else value.get("", current_attn)
  306. self._attn_implementation_internal = attn_implementation
  307. # Set it recursively on the subconfigs
  308. for subconfig_key in self.sub_configs:
  309. subconfig = getattr(self, subconfig_key, None)
  310. if subconfig is not None:
  311. current_subconfig_attn = getattr(subconfig, "_attn_implementation", None)
  312. sub_implementation = (
  313. value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_attn)
  314. )
  315. subconfig._attn_implementation = sub_implementation
  316. @property
  317. def _experts_implementation(self):
  318. return self._experts_implementation_internal
  319. @_experts_implementation.setter
  320. def _experts_implementation(self, value: str | dict | None):
  321. """We set it recursively on the sub-configs as well"""
  322. # Set if for current config
  323. current_moe = getattr(self, "_experts_implementation", None)
  324. experts_implementation = value if not isinstance(value, dict) else value.get("", current_moe)
  325. self._experts_implementation_internal = experts_implementation
  326. # Set it recursively on the subconfigs
  327. for subconfig_key in self.sub_configs:
  328. subconfig = getattr(self, subconfig_key, None)
  329. if subconfig is not None:
  330. current_subconfig_moe = getattr(subconfig, "_experts_implementation", None)
  331. sub_implementation = (
  332. value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_moe)
  333. )
  334. subconfig._experts_implementation = sub_implementation
  335. @property
  336. def torch_dtype(self):
  337. logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
  338. return self.dtype
  339. @property
  340. def use_return_dict(self):
  341. logger.warning_once("`use_return_dict` is deprecated! Use `return_dict` instead!")
  342. return self.return_dict
  343. @torch_dtype.setter
  344. def torch_dtype(self, value):
  345. logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
  346. self.dtype = value
  347. def __setattr__(self, key, value):
  348. if key in super().__getattribute__("attribute_map"):
  349. key = super().__getattribute__("attribute_map")[key]
  350. super().__setattr__(key, value)
  351. def __getattribute__(self, key):
  352. if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
  353. key = super().__getattribute__("attribute_map")[key]
  354. return super().__getattribute__(key)
  355. def validate_output_attentions(self):
  356. if self.output_attentions and self._attn_implementation not in ["eager", None]:
  357. raise ValueError(
  358. "The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
  359. f"{self._attn_implementation}. Please set it to 'eager' instead."
  360. )
  361. def validate_architecture(self):
  362. """Part of `@strict`-powered validation. Validates the architecture of the config."""
  363. if (
  364. hasattr(self, "head_dim")
  365. and hasattr(self, "num_heads")
  366. and hasattr(self, "embed_dim")
  367. and self.head_dim * self.num_heads != self.embed_dim
  368. ):
  369. raise ValueError(
  370. f"The embed_dim ({self.embed_dim}) is not a multiple of the number of attention "
  371. f"heads ({self.num_heads})."
  372. )
  373. def validate_token_ids(self):
  374. """Part of `@strict`-powered validation. Validates the contents of the special tokens."""
  375. text_config = self.get_text_config(decoder=True)
  376. vocab_size = getattr(text_config, "vocab_size", None)
  377. if vocab_size is not None:
  378. # Check for all special tokens, e..g. pad_token_id, image_token_id, audio_token_id
  379. for value in text_config:
  380. if value.endswith("_token_id") and isinstance(value, int) and not 0 <= value < vocab_size:
  381. # Can't be an exception until we can load configs that fail validation: several configs on the Hub
  382. # store invalid special tokens, e.g. `pad_token_id=-1`
  383. logger.warning_once(
  384. f"Model config: {value} must be `None` or an integer within the vocabulary (between 0 "
  385. f"and {vocab_size - 1}), got {value}. This may result in unexpected behavior."
  386. )
  387. def validate_layer_type(self):
  388. """Check that `layer_types` is correctly defined."""
  389. if not (getattr(self, "layer_types", None) is not None and hasattr(self, "num_hidden_layers")):
  390. return
  391. elif not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in self.layer_types):
  392. raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES} but got {self.layer_types}")
  393. elif self.num_hidden_layers is not None and self.num_hidden_layers != len(self.layer_types):
  394. raise ValueError(
  395. f"`num_hidden_layers` ({self.num_hidden_layers}) must be equal to the number of layer types "
  396. f"({len(self.layer_types)})"
  397. )
  398. @property
  399. def rope_scaling(self):
  400. return self.rope_parameters
  401. @rope_scaling.setter
  402. def rope_scaling(self, value):
  403. self.rope_parameters = value
  404. def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
  405. """
  406. Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
  407. [`~PreTrainedConfig.from_pretrained`] class method.
  408. Args:
  409. save_directory (`str` or `os.PathLike`):
  410. Directory where the configuration JSON file will be saved (will be created if it does not exist).
  411. push_to_hub (`bool`, *optional*, defaults to `False`):
  412. Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
  413. repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
  414. namespace).
  415. kwargs (`dict[str, Any]`, *optional*):
  416. Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
  417. """
  418. if os.path.isfile(save_directory):
  419. raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
  420. generation_parameters = self._get_generation_parameters()
  421. if len(generation_parameters) > 0:
  422. raise ValueError(
  423. "Some generation parameters are set in the model config. These should go into `model.generation_config`"
  424. f"as opposed to `model.config`. \nGeneration parameters found: {str(generation_parameters)}",
  425. )
  426. os.makedirs(save_directory, exist_ok=True)
  427. if push_to_hub:
  428. commit_message = kwargs.pop("commit_message", None)
  429. repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
  430. repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id
  431. files_timestamps = self._get_files_timestamps(save_directory)
  432. # This attribute is important to know on load, but should not be serialized on save.
  433. if "transformers_weights" in self:
  434. delattr(self, "transformers_weights")
  435. # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
  436. # loaded from the Hub.
  437. if self._auto_class is not None:
  438. custom_object_save(self, save_directory, config=self)
  439. # If we save using the predefined names, we can load using `from_pretrained`
  440. output_config_file = os.path.join(save_directory, CONFIG_NAME)
  441. # Strict validation at save-time: prevent bad patterns from propagating
  442. # Using `strict` decorator guarantees that `self.validate` exists , but not all
  443. # model config might have the decorator added
  444. if hasattr(self, "validate"):
  445. self.validate()
  446. self.to_json_file(output_config_file, use_diff=True)
  447. logger.info(f"Configuration saved in {output_config_file}")
  448. if push_to_hub:
  449. self._upload_modified_files(
  450. save_directory,
  451. repo_id,
  452. files_timestamps,
  453. commit_message=commit_message,
  454. token=kwargs.get("token"),
  455. )
  456. @classmethod
  457. def from_pretrained(
  458. cls: type[SpecificPreTrainedConfigType],
  459. pretrained_model_name_or_path: str | os.PathLike,
  460. cache_dir: str | os.PathLike | None = None,
  461. force_download: bool = False,
  462. local_files_only: bool = False,
  463. token: str | bool | None = None,
  464. revision: str = "main",
  465. **kwargs,
  466. ) -> SpecificPreTrainedConfigType:
  467. r"""
  468. Instantiate a [`PreTrainedConfig`] (or a derived class) from a pretrained model configuration.
  469. Args:
  470. pretrained_model_name_or_path (`str` or `os.PathLike`):
  471. This can be either:
  472. - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
  473. huggingface.co.
  474. - a path to a *directory* containing a configuration file saved using the
  475. [`~PreTrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
  476. - a path to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
  477. cache_dir (`str` or `os.PathLike`, *optional*):
  478. Path to a directory in which a downloaded pretrained model configuration should be cached if the
  479. standard cache should not be used.
  480. force_download (`bool`, *optional*, defaults to `False`):
  481. Whether or not to force to (re-)download the configuration files and override the cached versions if
  482. they exist.
  483. proxies (`dict[str, str]`, *optional*):
  484. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  485. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
  486. token (`str` or `bool`, *optional*):
  487. The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  488. the token generated when running `hf auth login` (stored in `~/.huggingface`).
  489. revision (`str`, *optional*, defaults to `"main"`):
  490. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  491. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  492. identifier allowed by git.
  493. <Tip>
  494. To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
  495. </Tip>
  496. return_unused_kwargs (`bool`, *optional*, defaults to `False`):
  497. If `False`, then this function returns just the final configuration object.
  498. If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
  499. dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
  500. part of `kwargs` which has not been used to update `config` and is otherwise ignored.
  501. subfolder (`str`, *optional*, defaults to `""`):
  502. In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
  503. specify the folder name here.
  504. kwargs (`dict[str, Any]`, *optional*):
  505. The values in kwargs of any keys which are configuration attributes will be used to override the loaded
  506. values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
  507. by the `return_unused_kwargs` keyword parameter.
  508. Returns:
  509. [`PreTrainedConfig`]: The configuration object instantiated from this pretrained model.
  510. Examples:
  511. ```python
  512. # We can't instantiate directly the base class *PreTrainedConfig* so let's show the examples on a
  513. # derived class: BertConfig
  514. config = BertConfig.from_pretrained(
  515. "google-bert/bert-base-uncased"
  516. ) # Download configuration from huggingface.co and cache.
  517. config = BertConfig.from_pretrained(
  518. "./test/saved_model/"
  519. ) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
  520. config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
  521. config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
  522. assert config.output_attentions == True
  523. config, unused_kwargs = BertConfig.from_pretrained(
  524. "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
  525. )
  526. assert config.output_attentions == True
  527. assert unused_kwargs == {"foo": False}
  528. ```"""
  529. kwargs["cache_dir"] = cache_dir
  530. kwargs["force_download"] = force_download
  531. kwargs["local_files_only"] = local_files_only
  532. kwargs["revision"] = revision
  533. config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
  534. if cls.base_config_key and cls.base_config_key in config_dict:
  535. config_dict = config_dict[cls.base_config_key]
  536. if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
  537. # sometimes the config has no `base_config_key` if the config is used in several composite models
  538. # e.g. LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning
  539. for v in config_dict.values():
  540. if isinstance(v, dict) and v.get("model_type") == cls.model_type:
  541. config_dict = v
  542. # raise warning only if we still can't see a match in `model_type`
  543. if config_dict["model_type"] != cls.model_type:
  544. logger.warning(
  545. f"You are using a model of type `{config_dict['model_type']}` to instantiate a model of type "
  546. f"`{cls.model_type}`. This may be expected if you are loading a checkpoint that shares a subset "
  547. f"of the architecture (e.g., loading a `sam2_video` checkpoint into `Sam2Model`), but is otherwise "
  548. f"not supported and can yield errors. Please verify that the checkpoint is compatible with the "
  549. f"model you are instantiating."
  550. )
  551. return cls.from_dict(config_dict, **kwargs)
  552. @classmethod
  553. def get_config_dict(
  554. cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
  555. ) -> tuple[dict[str, Any], dict[str, Any]]:
  556. """
  557. From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
  558. [`PreTrainedConfig`] using `from_dict`.
  559. Parameters:
  560. pretrained_model_name_or_path (`str` or `os.PathLike`):
  561. The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
  562. Returns:
  563. `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
  564. """
  565. original_kwargs = copy.deepcopy(kwargs)
  566. # Get config dict associated with the base config file
  567. config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
  568. if config_dict is None:
  569. return {}, kwargs
  570. if "_commit_hash" in config_dict:
  571. original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
  572. # That config file may point us toward another config file to use.
  573. if "configuration_files" in config_dict:
  574. configuration_file = get_configuration_file(config_dict["configuration_files"])
  575. config_dict, kwargs = cls._get_config_dict(
  576. pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
  577. )
  578. return config_dict, kwargs
  @classmethod
  def _get_config_dict(
      cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
  ) -> tuple[dict[str, Any], dict[str, Any]]:
      """
      Resolve a single configuration file and load it into a dictionary.

      `pretrained_model_name_or_path` is either a path to a configuration file (looked up relative to `subfolder`),
      or anything `cached_file` can resolve (local directory or Hub repo id).

      The following kwargs are consumed here:
          - loading options: `cache_dir`, `force_download`, `proxies`, `token`, `local_files_only`, `revision`,
            `subfolder`, `trust_remote_code`;
          - internal options: `_from_pipeline`, `_from_auto`, `_commit_hash`, `_configuration_file`.

      Every other kwarg is returned untouched, including `gguf_file` and `model_type`.

      Returns:
          `tuple[dict | None, dict]`: The configuration dictionary (or `None` when `cached_file` resolved no file)
          and the remaining kwargs.

      Raises:
          `OSError`: If the file cannot be resolved/downloaded, or is not a valid JSON file.
      """
      cache_dir = kwargs.pop("cache_dir", None)
      force_download = kwargs.pop("force_download", False)
      proxies = kwargs.pop("proxies", None)
      token = kwargs.pop("token", None)
      local_files_only = kwargs.pop("local_files_only", False)
      revision = kwargs.pop("revision", None)
      trust_remote_code = kwargs.pop("trust_remote_code", None)
      subfolder = kwargs.pop("subfolder", "")
      from_pipeline = kwargs.pop("_from_pipeline", None)
      from_auto_class = kwargs.pop("_from_auto", False)
      commit_hash = kwargs.pop("_commit_hash", None)
      # Always consume this internal override so it never leaks into the returned (unused) kwargs.
      requested_configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
      # Read with `get` rather than `pop`: `gguf_file` stays in the returned kwargs.
      gguf_file = kwargs.get("gguf_file")

      if trust_remote_code is True:
          logger.warning(
              "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
              " ignored."
          )

      user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
      if from_pipeline is not None:
          user_agent["using_pipeline"] = from_pipeline

      pretrained_model_name_or_path = str(pretrained_model_name_or_path)

      is_local = os.path.isdir(pretrained_model_name_or_path)
      local_config_file = os.path.join(subfolder, pretrained_model_name_or_path)
      if os.path.isfile(local_config_file):
          # Special case when pretrained_model_name_or_path is a local file: load exactly the path that was checked
          # (when `subfolder` is empty this is just `pretrained_model_name_or_path`).
          resolved_config_file = local_config_file
          # Needed by the `model_type` override warning below.
          configuration_file = os.path.basename(local_config_file)
          is_local = True
      else:
          configuration_file = requested_configuration_file if gguf_file is None else gguf_file

          try:
              # Load from local folder or from cache or download from model Hub and cache
              resolved_config_file = cached_file(
                  pretrained_model_name_or_path,
                  configuration_file,
                  cache_dir=cache_dir,
                  force_download=force_download,
                  proxies=proxies,
                  local_files_only=local_files_only,
                  token=token,
                  user_agent=user_agent,
                  revision=revision,
                  subfolder=subfolder,
                  _commit_hash=commit_hash,
              )
              if resolved_config_file is None:
                  return None, kwargs
              commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
          except OSError:
              # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
              # the original exception.
              raise
          except Exception as err:
              # For any other exception, we throw a generic error (chained to keep the root cause visible).
              raise OSError(
                  f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
                  " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
                  f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
                  f" containing a {configuration_file} file"
              ) from err

      try:
          if gguf_file:
              config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
          else:
              # Load config dict
              config_dict = cls._dict_from_json_file(resolved_config_file)

          config_dict["_commit_hash"] = commit_hash
      except (json.JSONDecodeError, UnicodeDecodeError) as err:
          raise OSError(f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file.") from err

      if is_local:
          logger.info(f"loading configuration file {resolved_config_file}")
      else:
          logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")

      # timm models are not saved with the model_type in the config file
      if "model_type" not in config_dict and is_timm_config_dict(config_dict):
          config_dict["model_type"] = "timm_wrapper"

      # Some checkpoints may contain the wrong (or no) model_type in the config file.
      # Allow the user to override it but warn them that it might not work.
      if "model_type" in kwargs and config_dict.get("model_type") != kwargs["model_type"]:
          logger.warning(
              f"{configuration_file} has 'model_type={config_dict.get('model_type')}' but you overrode "
              f"it with 'model_type={kwargs['model_type']}'. This may lead to unexpected behavior."
          )
          config_dict["model_type"] = kwargs["model_type"]

      return config_dict, kwargs
  @classmethod
  def from_dict(
      cls: type[SpecificPreTrainedConfigType], config_dict: dict[str, Any], **kwargs
  ) -> SpecificPreTrainedConfigType:
      """
      Instantiates a [`PreTrainedConfig`] from a Python dictionary of parameters.

      Args:
          config_dict (`dict[str, Any]`):
              Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
              retrieved from a pretrained checkpoint by leveraging the [`~PreTrainedConfig.get_config_dict`] method.
              The dictionary passed in is not modified.
          kwargs (`dict[str, Any]`):
              Additional parameters from which to initialize the configuration object.
              - The fields listed in `valid_fields` are passed to the constructor. `dtype`/`torch_dtype` are
                passed only when they are not `"auto"`.
              - Any kwarg naming an existing attribute of the built config is then set on it. A dict given for a
                nested sub-config is merged into that sub-config.
              - `return_unused_kwargs=True` additionally returns the kwargs that were not consumed.

      Returns:
          [`PreTrainedConfig`]: The configuration object instantiated from those parameters, or a
          `(config, unused_kwargs)` tuple when `return_unused_kwargs=True`.
      """
      return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
      # Work on a shallow copy so the overrides below don't leak into the caller's dictionary.
      config_dict = dict(config_dict)
      # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
      # (The key is known to be present in kwargs here, so it must be assigned: `setdefault` would be a no-op.)
      if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
          kwargs["_commit_hash"] = config_dict["_commit_hash"]

      # To remove arg here are those passed along for our internal telemetry but we still need to remove them
      to_remove = ["_from_auto", "_from_pipeline"]
      # Kwargs forwarded to the constructor (instead of being set after construction)
      valid_fields = [
          "num_labels",
          "attn_implementation",
          "experts_implementation",
          "output_attentions",
          "torch_dtype",
          "dtype",
          "name_or_path",
      ]
      for key, value in kwargs.items():
          if key not in valid_fields:
              continue
          if key not in ["torch_dtype", "dtype"]:
              config_dict[key] = value
              to_remove.append(key)
          elif value != "auto":
              config_dict[key] = value

      config = cls(**config_dict)

      for key, value in kwargs.items():
          if not hasattr(config, key):
              continue
          current_attr = getattr(config, key)
          # To authorize passing a custom subconfig as kwarg in models that have nested configs.
          # We need to update only custom kwarg values instead and keep other attr in subconfig.
          if isinstance(current_attr, PreTrainedConfig) and isinstance(value, dict):
              current_attr_updated = current_attr.to_dict()
              current_attr_updated.update(value)
              value = current_attr.__class__(**current_attr_updated)
          setattr(config, key, value)
          to_remove.append(key)

      for key in to_remove:
          kwargs.pop(key, None)

      # Lazy %-formatting: `repr(config)` serializes the whole config (`to_json_string`), so only pay for it
      # when INFO logging is actually enabled.
      logger.info("Model config %s", config)
      if return_unused_kwargs:
          return config, kwargs
      return config
  @classmethod
  def from_json_file(
      cls: type[SpecificPreTrainedConfigType], json_file: str | os.PathLike
  ) -> SpecificPreTrainedConfigType:
      """
      Build a [`PreTrainedConfig`] directly from a JSON file of parameters.

      The file is decoded by `_dict_from_json_file` and every entry is passed to the class constructor as a keyword
      argument; unlike `from_dict`, no extra kwargs are merged in.

      Args:
          json_file (`str` or `os.PathLike`):
              Path to the JSON file containing the parameters.

      Returns:
          [`PreTrainedConfig`]: The configuration object instantiated from that JSON file.
      """
      return cls(**cls._dict_from_json_file(json_file))
  736. @classmethod
  737. def _dict_from_json_file(cls, json_file: str | os.PathLike):
  738. with open(json_file, encoding="utf-8") as reader:
  739. text = reader.read()
  740. config_dict = json.loads(text)
  741. return cls._decode_special_floats(config_dict)
  @classmethod
  def _encode_special_floats(cls, obj: Any) -> Any:
      """
      Recursively replace the floats that plain JSON cannot represent portably with tagged objects.

      Python's JSON engine would write `NaN`, `Infinity` and `-Infinity` literals, which other JSON engines reject.
      Each such float therefore becomes the single-key dict `{_FLOAT_TAG_KEY: "NaN" | "Infinity" | "-Infinity"}`.
      Dicts are rebuilt with encoded values, lists and tuples are both returned as lists, and anything else is
      returned unchanged.
      """
      if isinstance(obj, dict):
          return {key: cls._encode_special_floats(item) for key, item in obj.items()}
      if isinstance(obj, (list, tuple)):
          return [cls._encode_special_floats(item) for item in obj]
      if not isinstance(obj, float):
          return obj

      if math.isnan(obj):
          tag = "NaN"
      elif math.isinf(obj):
          tag = "Infinity" if obj > 0 else "-Infinity"
      else:
          # an ordinary finite float is JSON-safe as is
          return obj
      return {_FLOAT_TAG_KEY: tag}
  @classmethod
  def _decode_special_floats(cls, obj: Any) -> Any:
      """
      Inverse of `_encode_special_floats`: recursively turn tagged objects back into their float values.

      A dict whose only key is `_FLOAT_TAG_KEY`, holding a string that is a key of `_FLOAT_TAG_VALUES`, is replaced by
      the mapped value; such a dict with an unknown string is returned as is. Other dicts and lists are rebuilt with
      decoded contents, and any other object (tuples included) is returned unchanged.
      """
      if isinstance(obj, list):
          return [cls._decode_special_floats(item) for item in obj]
      if not isinstance(obj, dict):
          return obj

      is_float_tag = set(obj) == {_FLOAT_TAG_KEY} and isinstance(obj[_FLOAT_TAG_KEY], str)
      if not is_float_tag:
          return {key: cls._decode_special_floats(item) for key, item in obj.items()}

      tag = obj[_FLOAT_TAG_KEY]
      return _FLOAT_TAG_VALUES[tag] if tag in _FLOAT_TAG_VALUES else obj
  def __eq__(self, other):
      # Only another config can compare equal, and then only when every instance attribute matches.
      if not isinstance(other, PreTrainedConfig):
          return False
      return self.__dict__ == other.__dict__
  def __repr__(self):
      # "<ClassName> <JSON>" built from `to_json_string()` with its default arguments.
      return " ".join((self.__class__.__name__, self.to_json_string()))
  def __iter__(self):
      # Iterating a config yields the names of its instance attributes, in insertion order.
      for attribute_name in self.__dict__:
          yield attribute_name
  def to_diff_dict(self) -> dict[str, Any]:
      """
      Serializes this instance to a Python dictionary that only keeps the attributes worth saving, for better
      readability.

      An attribute is kept when any of the following holds:
          - it is not part of the base `PreTrainedConfig()` defaults;
          - it differs from those base defaults;
          - it differs from the defaults of this config's own class. This comparison is skipped when the class has
            `has_no_defaults_at_init`.

      In addition:
          - `transformers_version` and `vocab_file` are always kept.
          - Nested sub-configs that also appear in the class defaults are diffed recursively, keeping their
            `model_type`.
          - `_name_or_path` is dropped.
          - `quantization_config` is re-added in full.
          - dtypes are converted to strings.

      Returns:
          `dict[str, Any]`: Dictionary of the attributes of this configuration instance that differ from the defaults.
      """
      config_dict = self.to_dict()

      # Get the default config dict (from a fresh PreTrainedConfig instance)
      default_config_dict = PreTrainedConfig().to_dict()

      # get class specific config dict (classes that cannot be built without arguments get no class defaults)
      class_config_dict = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}

      serializable_config_dict = {}

      # Only serialize values that differ from the defaults; `transformers_version` and `vocab_file` are always kept.
      for key, value in config_dict.items():
          if (
              isinstance(getattr(self, key, None), PreTrainedConfig)
              and key in class_config_dict
              and isinstance(class_config_dict[key], dict)
          ):
              # For nested configs we need to clean the diff recursively
              diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None))
              if "model_type" in value:
                  # Needs to be set even if it's not in the diff
                  diff["model_type"] = value["model_type"]
              serializable_config_dict[key] = diff
          elif (
              key not in default_config_dict
              or key == "transformers_version"
              or key == "vocab_file"
              or value != default_config_dict[key]
              # also keep values equal to the base default but different from this class's own default
              or (key in default_config_dict and value != class_config_dict.get(key, value))
          ):
              serializable_config_dict[key] = value

      self._remove_keys_not_serialized(serializable_config_dict)

      # Key removed only in diff dict
      if "_name_or_path" in serializable_config_dict:
          del serializable_config_dict["_name_or_path"]

      # quantization_config is always written in full (serialized via its `to_dict` unless already a dict or None)
      if hasattr(self, "quantization_config"):
          serializable_config_dict["quantization_config"] = (
              self.quantization_config.to_dict()
              if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
              else self.quantization_config
          )
      self.dict_dtype_to_str(serializable_config_dict)

      return serializable_config_dict
  def to_dict(self) -> dict[str, Any]:
      """
      Serializes this instance to a Python dictionary containing every attribute.

      The output starts from a deep copy of the instance `__dict__` and is then adjusted as follows:
          - the class `model_type` (when the class defines one) and the current `transformers_version` are added;
          - the unpacked `kwargs` entry is dropped;
          - nested configs become dicts, without their own `transformers_version`;
          - tuple values become lists;
          - internal-only keys are removed;
          - `quantization_config` is serialized;
          - dtypes are converted to strings.

      Returns:
          `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
      """

      def _tuples_to_lists(item):
          # Some defaults are tuples because dataclass fields can't have mutable defaults; convert them at any depth.
          if not isinstance(item, tuple):
              return item
          return [_tuples_to_lists(element) for element in item]

      serialized = copy.deepcopy(self.__dict__)
      if hasattr(self.__class__, "model_type"):
          serialized["model_type"] = self.__class__.model_type
      # Transformers version when serializing the model
      serialized["transformers_version"] = __version__
      # "kwargs" were already unpacked and set in the post init
      serialized.pop("kwargs", None)

      for name in list(serialized):
          field_value = serialized[name]
          if isinstance(field_value, PreTrainedConfig):
              # Nested configs (e.g. CLIP): serialize them too, keeping a single top-level version stamp
              nested_dict = field_value.to_dict()
              del nested_dict["transformers_version"]
              serialized[name] = nested_dict
          elif isinstance(field_value, tuple):
              serialized[name] = _tuples_to_lists(field_value)

      self._remove_keys_not_serialized(serialized)

      if hasattr(self, "quantization_config"):
          quantization_config = self.quantization_config
          serialized["quantization_config"] = (
              quantization_config
              if quantization_config is None or isinstance(quantization_config, dict)
              else quantization_config.to_dict()
          )

      self.dict_dtype_to_str(serialized)
      return serialized
  def to_json_string(self, use_diff: bool = True) -> str:
      """
      Serialize this configuration to a JSON string (2-space indent, sorted keys, trailing newline).

      Args:
          use_diff (`bool`, *optional*, defaults to `True`):
              When exactly `True`, only the output of `to_diff_dict` (values that differ from the defaults) is
              serialized; any other value serializes the complete `to_dict()` output.

      Returns:
          `str`: String containing the attributes of this configuration instance in JSON format.
      """
      payload = self.to_diff_dict() if use_diff is True else self.to_dict()
      # +/-Infinity and NaN are swapped for tagged objects so other JSON engines can read the result.
      payload = self._encode_special_floats(payload)
      return json.dumps(payload, indent=2, sort_keys=True) + "\n"
  886. def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True):
  887. """
  888. Save this instance to a JSON file.
  889. Args:
  890. json_file_path (`str` or `os.PathLike`):
  891. Path to the JSON file in which this configuration instance's parameters will be saved.
  892. use_diff (`bool`, *optional*, defaults to `True`):
  893. If set to `True`, only the difference between the config instance and the default `PreTrainedConfig()`
  894. is serialized to JSON file.
  895. """
  896. with open(json_file_path, "w", encoding="utf-8") as writer:
  897. writer.write(self.to_json_string(use_diff=use_diff))
  def update(self, config_dict: dict[str, Any]):
      """
      Set each entry of `config_dict` as an attribute of this configuration, overwriting any existing value.

      Args:
          config_dict (`dict[str, Any]`): Mapping of attribute names to their new values.
      """
      for attribute_name, new_value in config_dict.items():
          setattr(self, attribute_name, new_value)
  def update_from_string(self, update_str: str):
      """
      Updates attributes of this class with attributes from `update_str`.

      The expected format is comma-separated `key=value` pairs.
          - Ints, floats and strings are given as is.
          - Booleans use `true`/`false`. `1`/`0`, `y`/`n` and `yes`/`no` are also accepted, case-insensitively.

      For example:
      "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"

      Only the first `=` of a pair separates key from value, so string values may contain `=`. Empty segments (an
      empty `update_str`, a trailing comma) are ignored. The keys to change have to already exist in the config
      object; each new value is converted to the type of the current value.

      All pairs are parsed before any attribute is updated.

      Args:
          update_str (`str`): String with attributes that should be updated for this class.

      Raises:
          `ValueError`: Raised in any of these cases:
              - a pair has no `=`;
              - a key is not an existing attribute;
              - a boolean cannot be parsed;
              - an int/float conversion fails.
          `TypeError`: If the current value is not an int, float, bool or string.
      """
      updates = {}
      for pair in update_str.split(","):
          if not pair:
              continue
          # Split on the first "=" only, so values may themselves contain "=".
          key, separator, raw_value = pair.partition("=")
          if not separator:
              raise ValueError(f"can't parse '{pair}' in the update string: expected the format `key=value`")
          updates[key] = raw_value

      for k, v in updates.items():
          if not hasattr(self, k):
              raise ValueError(f"key {k} isn't in the original config dict")

          old_v = getattr(self, k)
          # bool must be checked before int, since bool is a subclass of int
          if isinstance(old_v, bool):
              if v.lower() in ["true", "1", "y", "yes"]:
                  v = True
              elif v.lower() in ["false", "0", "n", "no"]:
                  v = False
              else:
                  raise ValueError(f"can't derive true or false from {v} (key {k})")
          elif isinstance(old_v, int):
              v = int(v)
          elif isinstance(old_v, float):
              v = float(v)
          elif not isinstance(old_v, str):
              raise TypeError(
                  f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
              )

          setattr(self, k, v)
  def dict_dtype_to_str(self, d: dict[str, Any]) -> None:
      """
      Checks whether the passed dictionary and its nested dicts have a *dtype* key and if it's not None,
      converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
      string, which can then be stored in the json format. The dictionary is modified in place.

      Both the single-dtype and the per-key (dict) forms keep the text after the last `.` of `str(dtype)`. A
      dtype-like object whose string form has no `.` is therefore kept whole instead of raising `IndexError`.
      """
      dtype = d.get("dtype")
      if dtype is not None:
          if isinstance(dtype, dict):
              d["dtype"] = {k: str(v).split(".")[-1] for k, v in dtype.items()}
          # models like Emu3 can have "dtype" as token in config's vocabulary map,
          # so we also exclude int type here to avoid error in this special case.
          elif not isinstance(dtype, (str, int)):
              d["dtype"] = str(dtype).split(".")[-1]
      for value in d.values():
          if isinstance(value, dict):
              self.dict_dtype_to_str(value)
  def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
      """
      Strip, in place, the entries of `d` that must not be written when saving a config, then repeat for every nested
      dict so sub-configs are cleaned as well. The private `_output_attentions` entry is renamed to the public
      `output_attentions` key.
      """
      internal_only_keys = (
          "_is_quantized",
          "_auto_class",
          "_commit_hash",
          "_attn_implementation_internal",
          "_experts_implementation_internal",
          "ignore_keys_at_rope_validation",
          "base_model_tp_plan",
          "base_model_pp_plan",
      )
      for internal_key in internal_only_keys:
          d.pop(internal_key, None)

      if "_output_attentions" in d:
          d["output_attentions"] = d.pop("_output_attentions")

      nested_dicts = [value for value in d.values() if isinstance(value, dict)]
      for nested in nested_dicts:
          self._remove_keys_not_serialized(nested)
  @classmethod
  def register_for_auto_class(cls, auto_class="AutoConfig"):
      """
      Register this class with a given auto class. This should only be used for custom configurations as the ones in
      the library are already mapped with `AutoConfig`.

      The name of the auto class (a class object is reduced to its `__name__`) is validated against
      `transformers.models.auto` and stored on `cls._auto_class`.

      Args:
          auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
              The auto class to register this new configuration with.

      Raises:
          `ValueError`: If `transformers.models.auto` has no attribute with that name.
      """
      auto_class_name = auto_class if isinstance(auto_class, str) else auto_class.__name__

      import transformers.models.auto as auto_module

      if not hasattr(auto_module, auto_class_name):
          raise ValueError(f"{auto_class_name} is not a valid auto class.")

      cls._auto_class = auto_class_name
  def _get_generation_parameters(self) -> dict[str, Any]:
      """
      Collect the generation parameters that were set on this model config.

      Generation parameters should not be saved in a `PreTrainedConfig`. The returned mapping holds each key of
      `GenerationConfig._get_default_generation_params()` that meets all of these conditions:
          - it is not `use_cache`;
          - its value on `self` is not `None`;
          - it is absent from a default instance of this class. The default-instance comparison only applies when
            the class can be built without arguments.
      """
      class_defaults = {} if self.has_no_defaults_at_init else self.__class__().to_dict()
      found_params = {}
      for param_name in GenerationConfig._get_default_generation_params().keys():
          if param_name == "use_cache":
              # common key for most models, never reported
              continue
          param_value = getattr(self, param_name, None)
          if param_value is not None and param_name not in class_defaults:
              found_params[param_name] = param_value
      return found_params
  def get_text_config(self, decoder=None, encoder=None) -> "PreTrainedConfig":
      """
      Returns the text config related to the text input (encoder) or text output (decoder) of the model. The
      `decoder` and `encoder` input arguments can be used to specify which end of the model we are interested in,
      which is useful on models that have both text input and output modalities.

      There are three possible outcomes of using this method:
          1. On most models, it returns the original config instance itself.
          2. On newer (2024+) composite models, it returns the text section of the config, which is nested under a set
              of valid names (`text_encoder` for the encoder side; `decoder`, `generator` or `text_config` for the
              decoder side).
          3. On older (2023-) encoder-decoder models with a flat config, when `decoder` and `encoder` differ, it
              returns a deep copy with the requested side's prefixed attributes renamed:
              - `<side>_layers` becomes `num_hidden_layers`;
              - `<side>_attention_heads` becomes `num_attention_heads`;
              - any other `<side>_xxx` becomes `xxx`, redirected through `attribute_map` when mapped.
              The other side's prefixed attributes are left in place.

      Args:
          decoder (`Optional[bool]`, *optional*):
              If set to `True`, then only search for decoder config names.
          encoder (`Optional[bool]`, *optional*):
              If set to `True`, then only search for encoder config names.

      Returns:
          `PreTrainedConfig`: The selected text config (see the three outcomes above).

      Raises:
          `ValueError`: If more than one of the searched sub-config attributes is set to a non-`None` value.
      """
      return_both = decoder == encoder  # both unset or both set -> search all possible names

      decoder_possible_text_config_names = ("decoder", "generator", "text_config")
      encoder_possible_text_config_names = ("text_encoder",)
      if return_both:
          possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
      elif decoder:
          possible_text_config_names = decoder_possible_text_config_names
      else:
          possible_text_config_names = encoder_possible_text_config_names

      # Names of candidate sub-config attributes that exist and are not None
      valid_text_config_names = []
      for text_config_name in possible_text_config_names:
          if hasattr(self, text_config_name):
              text_config = getattr(self, text_config_name, None)
              if text_config is not None:
                  valid_text_config_names += [text_config_name]

      if len(valid_text_config_names) > 1:
          raise ValueError(
              f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
              "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly, "
              "e.g. `text_config = config.sub_config_name`"
          )
      elif len(valid_text_config_names) == 1:
          config_to_return = getattr(self, valid_text_config_names[0])
      else:
          config_to_return = self

      # handle legacy models with flat config structure, when we only want one of the configs
      if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
          # Work on a copy so `self` is left untouched by the renames below
          config_to_return = copy.deepcopy(config_to_return)
          prefix_to_keep = "decoder" if decoder else "encoder"
          for key in config_to_return.to_dict():
              # NOTE: We can't discard keys because:
              # 1) we can't truly delete a cls attribute on a dataclass; 2) we can't set the value to `None` due to
              # strict validation. So we just keep it as is, since there are only a couple old models falling in this condition
              # (i.e. the other side's prefixed keys are kept; only the requested side's keys are renamed below)
              if key.startswith(prefix_to_keep):
                  # [encoder/decoder]_layers -> num_hidden_layers
                  if key == prefix_to_keep + "_layers":
                      new_key = "num_hidden_layers"
                  # [encoder/decoder]_attention_heads -> num_attention_heads
                  elif key == prefix_to_keep + "_attention_heads":
                      new_key = "num_attention_heads"
                  # e.g. encoder_hidden_act -> hidden_act
                  # NOTE(review): a key equal to the bare prefix (e.g. "encoder") would yield an empty new_key — confirm
                  # no legacy config has such an attribute.
                  else:
                      new_key = key[len(prefix_to_keep) + 1 :]

                  # Does the class map the new key into a different attribute name at read time? if so, let's write
                  # into that attribute instead
                  if new_key in config_to_return.attribute_map:
                      new_key = config_to_return.attribute_map[new_key]

                  # Move the value from the prefixed instance attribute to the unprefixed one
                  value = getattr(config_to_return, key)
                  delattr(config_to_return, key)
                  setattr(config_to_return, new_key, value)

      return config_to_return
  def get_configuration_file(configuration_files: list[str]) -> str:
      """
      Get the configuration file to use for this version of transformers.

      Among the `config.<version>.json` entries of `configuration_files`, picks the one with the highest version that
      is not newer than the installed Transformers version. Falls back to `CONFIG_NAME` when there is none. Entries
      whose `<version>` part cannot be parsed as a version are ignored.

      Args:
          configuration_files (`list[str]`): The list of available configuration files.

      Returns:
          `str`: The configuration file to use.
      """
      versioned_files = []
      for file_name in configuration_files:
          if file_name.startswith("config.") and file_name.endswith(".json") and file_name != "config.json":
              raw_version = file_name.removeprefix("config.").removesuffix(".json")
              try:
                  versioned_files.append((version.parse(raw_version), file_name))
              except ValueError:
                  # NOTE(review): assumes `version` is `packaging.version`, whose `InvalidVersion` subclasses
                  # `ValueError`; such files are skipped instead of aborting the whole lookup.
                  continue

      # Defaults to CONFIG_NAME and then try to look at some newer versions.
      configuration_file = CONFIG_NAME
      transformers_version = version.parse(__version__)
      # Sort on the parsed versions, not the raw strings: lexicographically "4.10.0" < "4.9.0", which would make the
      # early `break` below skip valid candidates.
      for file_version, file_name in sorted(versioned_files, key=lambda entry: entry[0]):
          if file_version > transformers_version:
              # No point going further since the versions are sorted.
              break
          configuration_file = file_name

      return configuration_file
  def recursive_diff_dict(dict_a, dict_b, config_obj=None):
      """
      Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains
      the values from `dict_a` that differ from the reference values.

      Args:
          dict_a (`dict`): The serialized values to diff (e.g. a nested config's `to_dict()`).
          dict_b (`dict`): The default config dictionary. We want to remove values that are in this one; keys absent
              from it are always kept.
          config_obj (`PreTrainedConfig`, *optional*):
              The config object `dict_a` was serialized from. Its attributes are used to detect nested sub-configs,
              and a freshly constructed instance of its class provides the reference values. Keys missing from those
              class defaults, or every key when `config_obj` is `None`, are compared against `dict_b` instead of
              raising `KeyError`.

      Returns:
          `dict`: The subset of `dict_a` (recursively) whose values differ from the defaults.
      """
      diff = {}
      default = config_obj.__class__().to_dict() if config_obj is not None else {}
      for key, value in dict_a.items():
          obj_value = getattr(config_obj, str(key), None)
          # Cheap dict checks first: the sub-config test only matters when `dict_b` holds a nested dict for this key.
          if key in dict_b and isinstance(dict_b[key], dict) and isinstance(obj_value, PreTrainedConfig):
              diff[key] = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
          elif key not in dict_b or value != default.get(key, dict_b[key]):
              diff[key] = value
      return diff
  # Give the config class its own copy of `push_to_hub` so its docstring placeholders can be filled in for configs.
  # NOTE(review): presumably `copy_func` returns an independent function object, leaving the shared implementation's
  # docstring untouched — confirm against `copy_func`.
  PreTrainedConfig.push_to_hub = copy_func(PreTrainedConfig.push_to_hub)
  # `__doc__` is None when docstrings are stripped (e.g. `python -OO`), so only format it when present.
  if PreTrainedConfig.push_to_hub.__doc__ is not None:
      PreTrainedConfig.push_to_hub.__doc__ = PreTrainedConfig.push_to_hub.__doc__.format(
          object="config", object_class="AutoConfig", object_files="configuration file"
      )

  # The alias is only here for BC - we did not have the correct CamelCasing before
  PretrainedConfig = PreTrainedConfig
  def layer_type_validation(layer_types: list[str], num_hidden_layers: int | None = None, attention: bool = True):
      """
      Deprecated check of a per-layer `layer_types` list (logs a deprecation warning on every call).

      Raises:
          `ValueError`: If any entry is not in `ALLOWED_LAYER_TYPES`, or if `num_hidden_layers` is given and differs
          from `len(layer_types)`.

      NOTE(review): `attention` is not read in the visible body; it appears to be kept for signature compatibility
      only — confirm before relying on it.
      """
      logger.warning(
          "`layer_type_validation` is deprecated and will be removed in v5.20. "
          "Use `PreTrainedConfig.validate_layer_type` instead"
      )
      if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
          raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")
      if num_hidden_layers is not None and num_hidden_layers != len(layer_types):
          raise ValueError(
              f"`num_hidden_layers` ({num_hidden_layers}) must be equal to the number of layer types "
              f"({len(layer_types)})"
          )