# configuration_utils.py
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import copy
  15. import json
  16. import os
  17. from dataclasses import dataclass
  18. from typing import Any
  19. @dataclass
  20. class DistributedConfig:
  21. """
  22. Base class for distributed configs
  23. """
  24. enable_expert_parallel: bool = False
  25. # TODO: add tp_plan, pp_plan, device_mesh etc..
  26. @classmethod
  27. def from_dict(cls, config_dict, **kwargs):
  28. """
  29. Constructs a DistributedConfig instance from a dictionary of parameters.
  30. Args:
  31. config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
  32. **kwargs: Additional keyword arguments to override dictionary values.
  33. Returns:
  34. DistributedConfig: Instance of DistributedConfig constructed from the dictionary.
  35. """
  36. config = cls(**config_dict)
  37. to_remove = []
  38. for key, value in kwargs.items():
  39. if hasattr(config, key):
  40. setattr(config, key, value)
  41. to_remove.append(key)
  42. for key in to_remove:
  43. kwargs.pop(key, None)
  44. return config
  45. # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
  46. def to_json_file(self, json_file_path: str | os.PathLike):
  47. """
  48. Save this instance to a JSON file.
  49. Args:
  50. json_file_path (`str` or `os.PathLike`):
  51. Path to the JSON file in which this configuration instance's parameters will be saved.
  52. use_diff (`bool`, *optional*, defaults to `True`):
  53. If set to `True`, only the difference between the config instance and the default
  54. `QuantizationConfig()` is serialized to JSON file.
  55. """
  56. with open(json_file_path, "w", encoding="utf-8") as writer:
  57. config_dict = self.to_dict()
  58. json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
  59. writer.write(json_string)
  60. def to_dict(self) -> dict[str, Any]:
  61. """
  62. Serializes this instance to a Python dictionary. Returns:
  63. `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
  64. """
  65. return copy.deepcopy(self.__dict__)
  66. # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
  67. def __iter__(self):
  68. """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
  69. yield from copy.deepcopy(self.__dict__).items()
  70. # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
  71. def __repr__(self):
  72. return f"{self.__class__.__name__} {self.to_json_string()}"
  73. def to_json_string(self):
  74. """
  75. Serializes this instance to a JSON formatted string.
  76. Returns:
  77. str: JSON formatted string representing the configuration instance.
  78. """
  79. return json.dumps(self.__dict__, indent=2) + "\n"
  80. def update(self, **kwargs):
  81. """
  82. Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
  83. returning all the unused kwargs.
  84. Args:
  85. kwargs (`Dict[str, Any]`):
  86. Dictionary of attributes to tentatively update this class.
  87. Returns:
  88. `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
  89. """
  90. to_remove = []
  91. for key, value in kwargs.items():
  92. if hasattr(self, key):
  93. setattr(self, key, value)
  94. to_remove.append(key)
  95. # Remove all the attributes that were updated, without modifying the input dict
  96. unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
  97. return unused_kwargs