_base.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Contains helpers to split tensors into shards."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass, field
  17. from typing import Any, TypeVar
  18. from .. import logging
# Placeholder for the tensor type handled by the caller; this module never inspects tensors itself,
# it only forwards them to the callbacks described by the two aliases below.
TensorT = TypeVar("TensorT")
# Callback returning the size of a tensor as an integer (compared against the shard limit in bytes).
TensorSizeFn_T = Callable[[TensorT], int]
# Callback returning an identifier for the storage backing a tensor, or `None` when there is none to track.
StorageIDFn_T = Callable[[TensorT], Any | None]

# Default shard size limit; string limits are converted to bytes with `parse_size_to_int`.
MAX_SHARD_SIZE = "5GB"

# Multipliers (in bytes) for the unit suffixes understood by `parse_size_to_int`.
# These are decimal units (powers of 10), not binary ones (powers of 2).
SIZE_UNITS = {
    "TB": 10**12,
    "GB": 10**9,
    "MB": 10**6,
    "KB": 10**3,
}

# NOTE(review): the logger is keyed on `__file__` (a filesystem path) rather than the more usual `__name__`
# (dotted module name) — confirm this is intentional, as it may affect where the logger sits in the hierarchy.
logger = logging.get_logger(__file__)
  30. @dataclass
  31. class StateDictSplit:
  32. is_sharded: bool = field(init=False)
  33. metadata: dict[str, Any]
  34. filename_to_tensors: dict[str, list[str]]
  35. tensor_to_filename: dict[str, str]
  36. def __post_init__(self):
  37. self.is_sharded = len(self.filename_to_tensors) > 1
  38. def split_state_dict_into_shards_factory(
  39. state_dict: dict[str, TensorT],
  40. *,
  41. get_storage_size: TensorSizeFn_T,
  42. filename_pattern: str,
  43. get_storage_id: StorageIDFn_T = lambda tensor: None,
  44. max_shard_size: int | str = MAX_SHARD_SIZE,
  45. ) -> StateDictSplit:
  46. """
  47. Split a model state dictionary in shards so that each shard is smaller than a given size.
  48. The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
  49. made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
  50. have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
  51. [6+2+2GB], [6+2GB], [6GB].
  52. > [!WARNING]
  53. > If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a
  54. > size greater than `max_shard_size`.
  55. Args:
  56. state_dict (`dict[str, Tensor]`):
  57. The state dictionary to save.
  58. get_storage_size (`Callable[[Tensor], int]`):
  59. A function that returns the size of a tensor when saved on disk in bytes.
  60. get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*):
  61. A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the
  62. same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage
  63. during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id.
  64. filename_pattern (`str`, *optional*):
  65. The pattern to generate the files names in which the model will be saved. Pattern must be a string that
  66. can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
  67. max_shard_size (`int` or `str`, *optional*):
  68. The maximum size of each shard, in bytes. Defaults to 5GB.
  69. Returns:
  70. [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
  71. """
  72. storage_id_to_tensors: dict[Any, list[str]] = {}
  73. shard_list: list[dict[str, TensorT]] = []
  74. current_shard: dict[str, TensorT] = {}
  75. current_shard_size = 0
  76. total_size = 0
  77. if isinstance(max_shard_size, str):
  78. max_shard_size = parse_size_to_int(max_shard_size)
  79. for key, tensor in state_dict.items():
  80. # when bnb serialization is used the weights in the state dict can be strings
  81. # check: https://github.com/huggingface/transformers/pull/24416 for more details
  82. if isinstance(tensor, str):
  83. logger.info("Skipping tensor %s as it is a string (bnb serialization)", key)
  84. continue
  85. # If a `tensor` shares the same underlying storage as another tensor, we put `tensor` in the same `block`
  86. storage_id = get_storage_id(tensor) # type: ignore[invalid-argument-type]
  87. if storage_id is not None:
  88. if storage_id in storage_id_to_tensors:
  89. # We skip this tensor for now and will reassign to correct shard later
  90. storage_id_to_tensors[storage_id].append(key)
  91. continue
  92. else:
  93. # This is the first tensor with this storage_id, we create a new entry
  94. # in the storage_id_to_tensors dict => we will assign the shard id later
  95. storage_id_to_tensors[storage_id] = [key]
  96. # Compute tensor size
  97. tensor_size = get_storage_size(tensor) # type: ignore[invalid-argument-type]
  98. # If this tensor is bigger than the maximal size, we put it in its own shard
  99. if tensor_size > max_shard_size:
  100. total_size += tensor_size
  101. shard_list.append({key: tensor})
  102. continue
  103. # If this tensor is going to tip up over the maximal size, we split.
  104. # Current shard already has some tensors, we add it to the list of shards and create a new one.
  105. if current_shard_size + tensor_size > max_shard_size:
  106. shard_list.append(current_shard)
  107. current_shard = {}
  108. current_shard_size = 0
  109. # Add the tensor to the current shard
  110. current_shard[key] = tensor
  111. current_shard_size += tensor_size
  112. total_size += tensor_size
  113. # Add the last shard
  114. if len(current_shard) > 0:
  115. shard_list.append(current_shard)
  116. nb_shards = len(shard_list)
  117. # Loop over the tensors that share the same storage and assign them together
  118. for storage_id, keys in storage_id_to_tensors.items():
  119. # Let's try to find the shard where the first tensor of this storage is and put all tensors in the same shard
  120. for shard in shard_list:
  121. if keys[0] in shard:
  122. for key in keys:
  123. shard[key] = state_dict[key]
  124. break
  125. # If we only have one shard, we return it => no need to build the index
  126. if nb_shards == 1:
  127. filename = filename_pattern.format(suffix="")
  128. return StateDictSplit(
  129. metadata={"total_size": total_size},
  130. filename_to_tensors={filename: list(state_dict.keys())},
  131. tensor_to_filename={key: filename for key in state_dict.keys()},
  132. )
  133. # Now that each tensor is assigned to a shard, let's assign a filename to each shard
  134. tensor_name_to_filename = {}
  135. filename_to_tensors = {}
  136. for idx, shard in enumerate(shard_list):
  137. filename = filename_pattern.format(suffix=f"-{idx + 1:05d}-of-{nb_shards:05d}")
  138. for key in shard:
  139. tensor_name_to_filename[key] = filename
  140. filename_to_tensors[filename] = list(shard.keys())
  141. # Build the index and return
  142. return StateDictSplit(
  143. metadata={"total_size": total_size},
  144. filename_to_tensors=filename_to_tensors,
  145. tensor_to_filename=tensor_name_to_filename,
  146. )
  147. def parse_size_to_int(size_as_str: str) -> int:
  148. """
  149. Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).
  150. Supported units are "TB", "GB", "MB", "KB".
  151. Args:
  152. size_as_str (`str`): The size to convert. Will be directly returned if an `int`.
  153. Example:
  154. ```py
  155. >>> parse_size_to_int("5MB")
  156. 5000000
  157. ```
  158. """
  159. size_as_str = size_as_str.strip()
  160. # Parse unit
  161. unit = size_as_str[-2:].upper()
  162. if unit not in SIZE_UNITS:
  163. raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.")
  164. multiplier = SIZE_UNITS[unit]
  165. # Parse value
  166. try:
  167. value = float(size_as_str[:-2].strip())
  168. except ValueError as e:
  169. raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e
  170. return int(value * multiplier)