| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- # Copyright 2025 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
- import json
- import os
- from dataclasses import dataclass
- from typing import Any
- @dataclass
- class DistributedConfig:
- """
- Base class for distributed configs
- """
- enable_expert_parallel: bool = False
- # TODO: add tp_plan, pp_plan, device_mesh etc..
- @classmethod
- def from_dict(cls, config_dict, **kwargs):
- """
- Constructs a DistributedConfig instance from a dictionary of parameters.
- Args:
- config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
- **kwargs: Additional keyword arguments to override dictionary values.
- Returns:
- DistributedConfig: Instance of DistributedConfig constructed from the dictionary.
- """
- config = cls(**config_dict)
- to_remove = []
- for key, value in kwargs.items():
- if hasattr(config, key):
- setattr(config, key, value)
- to_remove.append(key)
- for key in to_remove:
- kwargs.pop(key, None)
- return config
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
- def to_json_file(self, json_file_path: str | os.PathLike):
- """
- Save this instance to a JSON file.
- Args:
- json_file_path (`str` or `os.PathLike`):
- Path to the JSON file in which this configuration instance's parameters will be saved.
- use_diff (`bool`, *optional*, defaults to `True`):
- If set to `True`, only the difference between the config instance and the default
- `QuantizationConfig()` is serialized to JSON file.
- """
- with open(json_file_path, "w", encoding="utf-8") as writer:
- config_dict = self.to_dict()
- json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
- writer.write(json_string)
- def to_dict(self) -> dict[str, Any]:
- """
- Serializes this instance to a Python dictionary. Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
- """
- return copy.deepcopy(self.__dict__)
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
- def __iter__(self):
- """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
- yield from copy.deepcopy(self.__dict__).items()
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
- def __repr__(self):
- return f"{self.__class__.__name__} {self.to_json_string()}"
- def to_json_string(self):
- """
- Serializes this instance to a JSON formatted string.
- Returns:
- str: JSON formatted string representing the configuration instance.
- """
- return json.dumps(self.__dict__, indent=2) + "\n"
- def update(self, **kwargs):
- """
- Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
- returning all the unused kwargs.
- Args:
- kwargs (`Dict[str, Any]`):
- Dictionary of attributes to tentatively update this class.
- Returns:
- `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
- """
- to_remove = []
- for key, value in kwargs.items():
- if hasattr(self, key):
- setattr(self, key, value)
- to_remove.append(key)
- # Remove all the attributes that were updated, without modifying the input dict
- unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
- return unused_kwargs
|