accelerate.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Some of the functions here are derived from the `accelerate` library, with some tweaks for better performance
  16. and simplicity/ease of use.
  17. """
  18. import copy
  19. import inspect
  20. import os
  21. import re
  22. from collections import OrderedDict, defaultdict
  23. from typing import TYPE_CHECKING
  24. from safetensors import safe_open
  25. from safetensors.torch import save_file
  26. from ..utils import (
  27. is_accelerate_available,
  28. is_torch_available,
  29. is_torch_xpu_available,
  30. logging,
  31. )
  32. from ..utils.quantization_config import QuantizationMethod
  33. from .deepspeed import is_deepspeed_zero3_enabled
  34. from .fsdp import is_fsdp_enabled
  35. if is_torch_available():
  36. import torch
  37. import torch.nn as nn
  38. if is_accelerate_available():
  39. from accelerate import dispatch_model
  40. from accelerate.utils import get_max_memory as accelerate_max_memory
  41. from accelerate.utils.modeling import clean_device_map, get_max_layer_size
  42. if TYPE_CHECKING:
  43. from ..modeling_utils import PreTrainedModel
  44. from ..quantizers import HfQuantizer
  45. logger = logging.get_logger(__name__)
  46. def get_module_size_with_ties(
  47. tied_params,
  48. module_size,
  49. module_sizes,
  50. modules_to_treat,
  51. ) -> tuple[int, list[str], list[nn.Module]]:
  52. """
  53. Calculate the total size of a module, including its tied parameters.
  54. Args:
  55. tied_params (`List[str]`): The list of tied parameters.
  56. module_size (`int`): The size of the module without tied parameters.
  57. module_sizes (`Dict[str, int]`): A dictionary mapping each layer name to its size.
  58. modules_to_treat (`List[Tuple[str, nn.Module]]`): The list of named modules to treat.
  59. Returns:
  60. `Tuple[int, List[str], List[nn.Module]]`: The total size of the module, the names of the tied modules, and the
  61. tied modules.
  62. """
  63. if len(tied_params) < 1:
  64. return module_size, [], []
  65. tied_module_names = []
  66. tied_modules = []
  67. for tied_param in tied_params:
  68. tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if tied_param.startswith(n + ".")][0]
  69. tied_module_names.append(modules_to_treat[tied_module_index][0])
  70. tied_modules.append(modules_to_treat[tied_module_index][1])
  71. module_size_with_ties = module_size
  72. for tied_param, tied_module_name in zip(tied_params, tied_module_names):
  73. module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]
  74. return module_size_with_ties, tied_module_names, tied_modules
  75. def check_and_set_device_map(device_map: "torch.device | int | str | dict | None") -> dict | str | None:
  76. from ..modeling_utils import get_torch_context_manager_or_global_device
  77. # Potentially detect context manager or global device, and use it (only if no device_map was provided)
  78. if device_map is None and not is_deepspeed_zero3_enabled():
  79. device_in_context = get_torch_context_manager_or_global_device()
  80. if device_in_context == torch.device("meta"):
  81. raise RuntimeError(
  82. "You are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`.\n"
  83. "This is an anti-pattern as `from_pretrained` wants to load existing weights.\nIf you want to initialize an "
  84. "empty model on the meta device, use the context manager or global device with `from_config`, or `ModelClass(config)`"
  85. )
  86. device_map = device_in_context
  87. # change device_map into a map if we passed an int, a str or a torch.device
  88. if isinstance(device_map, torch.device):
  89. device_map = {"": device_map}
  90. elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
  91. try:
  92. if device_map == "cuda":
  93. # setting to the local rank
  94. local_rank = int(os.environ.get("LOCAL_RANK", 0))
  95. device_map = f"cuda:{local_rank}"
  96. device_map = {"": torch.device(device_map)}
  97. except RuntimeError:
  98. raise ValueError(
  99. "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
  100. f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
  101. )
  102. elif isinstance(device_map, int):
  103. if device_map < 0:
  104. raise ValueError(
  105. "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
  106. )
  107. else:
  108. device_map = {"": device_map}
  109. if device_map is not None:
  110. if is_deepspeed_zero3_enabled():
  111. raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
  112. if not is_accelerate_available():
  113. raise ValueError(
  114. "Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
  115. "requires `accelerate`. You can install it with `pip install accelerate`"
  116. )
  117. return device_map
  118. def compute_module_sizes(
  119. model: "PreTrainedModel",
  120. hf_quantizer: "HfQuantizer | None" = None,
  121. buffers_only: bool = False,
  122. only_modules: bool = True,
  123. ) -> tuple[dict[str, int], dict[str, int]]:
  124. """
  125. Compute the size of each submodule of a given model (in bytes).
  126. Returns a tuple of 2 dicts, the fist one containing a mapping of all the modules and the corresponding size
  127. in bytes, and the 2nd one containing a mapping from all leaf modules (modules containing parameters, the end of
  128. the model graph) and the corresponding sizes.
  129. If `only_modules` is set to False, the first mapping will not only contain the size of all modules, but also
  130. the size of all parameters and buffers.
  131. """
  132. all_module_sizes = defaultdict(int)
  133. leaves_module_sizes = defaultdict(int)
  134. if buffers_only:
  135. iterator = model.named_buffers()
  136. else:
  137. # We need parameters + buffers here, as state_dict does not count non-persistent buffers which are taking space
  138. def all_tensors():
  139. yield from model.named_parameters()
  140. yield from model.named_buffers()
  141. iterator = all_tensors()
  142. tied_keys = getattr(model, "all_tied_weights_keys", {}).keys()
  143. for name, param in iterator:
  144. # Do not count tied keys (the model is usually not tied yet here, so they will appear in the iterator)
  145. # If the model is already tied, then they simply do not appear in the iterator anyway (remove_duplicates=True by default)
  146. if name in tied_keys:
  147. continue
  148. if hf_quantizer is not None:
  149. dtype_size = hf_quantizer.param_element_size(model, name, param)
  150. else:
  151. dtype_size = param.element_size()
  152. size = param.numel() * dtype_size
  153. name_parts = name.split(".")
  154. for idx in range(len(name_parts)):
  155. all_module_sizes[".".join(name_parts[:idx])] += size
  156. if "." in name:
  157. leaves_module_sizes[name.rsplit(".", 1)[0]] += size
  158. # If we want to also have the full leaves in `all_module_sizes`
  159. if not only_modules:
  160. all_module_sizes[name] += size
  161. return all_module_sizes, leaves_module_sizes
  162. def compute_module_total_buffer_size(model: nn.Module, hf_quantizer: "HfQuantizer | None" = None):
  163. """
  164. Compute the total size of buffers in each submodule of a given model.
  165. """
  166. module_sizes, _ = compute_module_sizes(model, hf_quantizer, buffers_only=True)
  167. return module_sizes.get("", 0)
  168. def get_max_memory(max_memory: dict[int | str, int | str] | None = None):
  169. """
  170. Get the maximum memory available if nothing is passed, converts string to int otherwise.
  171. Note: we need to overwrite this as accelerate does not take into account torch allocated but unused device memory...
  172. """
  173. # Get the max memory (it only uses free gpu memory, not torch allocated but free memory...)
  174. final_max_memory = accelerate_max_memory(max_memory)
  175. # Adjust for allocated but free memory
  176. for device_name in final_max_memory:
  177. if isinstance(device_name, int): # it's a GPU device
  178. # Only cuda and xpu use caching memory allocator
  179. if is_torch_xpu_available():
  180. unused_memory = torch.xpu.memory_reserved(device_name) - torch.xpu.memory_allocated(device_name)
  181. elif torch.cuda.is_available():
  182. unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name)
  183. else:
  184. unused_memory = 0
  185. # Add the pre-allocated but unused device memory
  186. final_max_memory[device_name] += unused_memory
  187. # Still respect the `max_memory` passed by the user if any
  188. if max_memory is not None and device_name in max_memory:
  189. final_max_memory[device_name] = min(max_memory[device_name], final_max_memory[device_name])
  190. # If the user does not provide `max_memory`, accelerate sets the WHOLE cpu available memory as available.
  191. # This is unwanted, as we don't want to set extremely tight bound and pressure for cpu if we are memory-constrained,
  192. # especially if the model uses WeightConverter (because there will be some uncontrollable cpu memory spikes during
  193. # the conversions before we resave the weights). In those cases, it's better to offload to disk a bit more
  194. # if we were in-between, as otherwise we blow-up cpu memory
  195. if max_memory is None and "cpu" in final_max_memory:
  196. final_max_memory["cpu"] *= 0.90
  197. return final_max_memory
  198. def get_balanced_memory(
  199. model: "PreTrainedModel",
  200. max_memory: dict[int | str, int | str] | None = None,
  201. no_split_module_classes: set[str] | None = None,
  202. hf_quantizer: "HfQuantizer | None" = None,
  203. low_zero: bool = False,
  204. ):
  205. """
  206. Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU.
  207. <Tip>
  208. All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
  209. meta device (as it would if initialized within the `init_empty_weights` context manager).
  210. </Tip>
  211. Args:
  212. model (`PreTrainedModel`):
  213. The model to analyze.
  214. max_memory (`Dict`, *optional*):
  215. A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
  216. Example: `max_memory={0: "1GB"}`.
  217. no_split_module_classes (`set[str]`, *optional*):
  218. A set of layer class names that should never be split across device (for instance any layer that has a
  219. residual connection).
  220. hf_quantizer (`HfQuantizer`, *optional*):
  221. A quantizer for the model.
  222. low_zero (`bool`, *optional*):
  223. Minimizes the number of weights on GPU 0, which is convenient when it's used for other operations (like the
  224. Transformers generate function).
  225. """
  226. # Get default / clean up max_memory
  227. user_not_set_max_memory = max_memory is None
  228. max_memory = get_max_memory(max_memory)
  229. # Check the number of accelerators available
  230. accelerator_max_memory = copy.deepcopy(max_memory)
  231. _, _ = accelerator_max_memory.pop("cpu", None), accelerator_max_memory.pop("disk", None)
  232. num_devices = len([d for d in accelerator_max_memory if accelerator_max_memory[d] > 0])
  233. if num_devices == 0:
  234. return max_memory
  235. if num_devices == 1:
  236. # We cannot do low_zero on just one GPU, but we will still reserve some memory for the buffer
  237. low_zero = False
  238. # If user just asked us to handle memory usage, we should avoid OOM
  239. if user_not_set_max_memory:
  240. for key in max_memory.keys():
  241. if isinstance(key, int):
  242. max_memory[key] *= 0.9 # 90% is a good compromise
  243. logger.info(
  244. f"We will use 90% of the memory on device {key} for storing the model, and 10% for the buffer to avoid OOM. "
  245. "You can set `max_memory` in to a higher value to use more memory (at your own risk)."
  246. )
  247. break # only one device
  248. module_sizes, leave_modules_sizes = compute_module_sizes(model, hf_quantizer)
  249. per_gpu = module_sizes[""] // (num_devices - 1 if low_zero else num_devices)
  250. # We can't just set the memory to model_size // num_devices as it will end being too small: each GPU will get
  251. # slightly less layers and some layers will end up offload at the end. So this function computes a buffer size to
  252. # add which is the biggest of:
  253. # - the size of no split block (if applicable)
  254. # - the mean of the layer sizes
  255. if no_split_module_classes is None:
  256. no_split_module_classes = []
  257. elif not isinstance(no_split_module_classes, (list, tuple, set)):
  258. no_split_module_classes = [no_split_module_classes]
  259. # Identify the size of the no_split_block modules
  260. buffer = 0
  261. if len(no_split_module_classes) > 0:
  262. no_split_children = {}
  263. for name, size in module_sizes.items():
  264. if name == "":
  265. continue
  266. submodule = model.get_submodule(name)
  267. class_name = submodule.__class__.__name__
  268. if class_name in no_split_module_classes and class_name not in no_split_children:
  269. no_split_children[class_name] = size
  270. if set(no_split_children.keys()) == set(no_split_module_classes):
  271. break
  272. buffer = max(no_split_children.values()) if len(no_split_children) > 0 else 0
  273. mean_leaves = int(sum(leave_modules_sizes.values()) / max(len(leave_modules_sizes), 1))
  274. buffer = int(1.25 * max(buffer, mean_leaves))
  275. per_gpu += buffer
  276. # Sorted list of GPUs id (we may have some gpu ids not included in the our max_memory list - let's ignore them)
  277. gpus_idx_list = sorted(
  278. device_id for device_id, device_mem in max_memory.items() if isinstance(device_id, int) and device_mem > 0
  279. )
  280. # The last device is left with max_memory just in case the buffer is not enough.
  281. for idx in gpus_idx_list[:-1]:
  282. max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx])
  283. if low_zero:
  284. min_zero = max(0, module_sizes[""] - sum([max_memory[i] for i in range(1, num_devices)]))
  285. max_memory[0] = min(min_zero, max_memory[0])
  286. return max_memory
  287. def _get_device_map(
  288. model: "PreTrainedModel",
  289. device_map: dict | str | None,
  290. max_memory: dict | None,
  291. hf_quantizer: "HfQuantizer | None",
  292. ) -> dict:
  293. """Compute the final `device_map` to use if we passed a value in ['auto', 'balanced', 'balanced_low_0', 'sequential'].
  294. Otherwise, we check for any device inconsistencies in the device_map.
  295. """
  296. if isinstance(device_map, str):
  297. no_split_modules = model._no_split_modules
  298. if device_map != "sequential":
  299. inferred_max_memory = get_balanced_memory(
  300. model,
  301. max_memory=max_memory,
  302. no_split_module_classes=no_split_modules,
  303. hf_quantizer=hf_quantizer,
  304. low_zero=(device_map == "balanced_low_0"),
  305. )
  306. else:
  307. inferred_max_memory = get_max_memory(max_memory)
  308. if hf_quantizer is not None:
  309. inferred_max_memory = hf_quantizer.adjust_max_memory(inferred_max_memory)
  310. device_map = infer_auto_device_map(
  311. model,
  312. max_memory=inferred_max_memory,
  313. no_split_module_classes=no_split_modules,
  314. hf_quantizer=hf_quantizer,
  315. )
  316. if hf_quantizer is not None:
  317. hf_quantizer.validate_environment(device_map=device_map)
  318. return device_map
  319. def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers):
  320. device_map_kwargs = {
  321. "device_map": device_map,
  322. "offload_dir": offload_folder,
  323. "offload_index": offload_index,
  324. "offload_buffers": offload_buffers,
  325. }
  326. if "skip_keys" in inspect.signature(dispatch_model).parameters:
  327. device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
  328. # For HQQ method we force-set the hooks for single GPU envs
  329. if (
  330. "force_hooks" in inspect.signature(dispatch_model).parameters
  331. and hf_quantizer is not None
  332. and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
  333. ):
  334. device_map_kwargs["force_hooks"] = True
  335. if (
  336. hf_quantizer is not None
  337. and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
  338. and isinstance(device_map, dict)
  339. and ("cpu" in device_map.values() or "disk" in device_map.values())
  340. ):
  341. device_map_kwargs["offload_buffers"] = True
  342. if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
  343. dispatch_model(model, **device_map_kwargs)
  344. def expand_device_map(device_map: dict | None, param_names: list[str]):
  345. """
  346. Expand a device map to return the correspondence parameter name to device.
  347. """
  348. if device_map is None:
  349. return dict.fromkeys(param_names, "cpu")
  350. # Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
  351. device_map_regex = re.compile(
  352. "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
  353. )
  354. new_device_map = {}
  355. for param in param_names:
  356. device_match = device_map_regex.match(param)
  357. new_device_map[param] = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
  358. return new_device_map
  359. def get_device(device_map: dict | None, param_name: str, valid_torch_device: bool = False) -> torch.device | str | int:
  360. """Return the device on which `param_name` should be according to the `device_map`. If `valid_torch_device` is `True`,
  361. then if the device is `"disk"`, `"cpu"` will be returned instead."""
  362. device = expand_device_map(device_map, [param_name])[param_name]
  363. if valid_torch_device and device == "disk":
  364. return "cpu"
  365. return device
  366. def accelerate_disk_offload(
  367. model: "PreTrainedModel",
  368. disk_offload_folder: str | None,
  369. checkpoint_files: list[str] | None,
  370. device_map: dict,
  371. sharded_metadata: dict | None,
  372. dtype: torch.dtype | None,
  373. weight_mapping=None,
  374. ):
  375. """
  376. Prepare the `disk_offload_index` that will be used for reading offloaded parameters. If reading from a safetensors
  377. file, parameters which do not need any special WeightConverter operation during loading (i.e. they are used as-is, or only
  378. renamed) will be mapped to where they already reside on disk. Otherwise, the parameters will be resaved inside
  379. `disk_offload_folder` during loading.
  380. """
  381. from ..core_model_loading import WeightRenaming, rename_source_key
  382. if disk_offload_folder is not None:
  383. os.makedirs(disk_offload_folder, exist_ok=True)
  384. is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
  385. renamings = []
  386. if weight_mapping is not None:
  387. renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
  388. # In this case, the offload index is simply the existing safetensors (except if using custom weight loading
  389. # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
  390. if is_offloaded_safetensors:
  391. meta_state_dict = model.state_dict()
  392. param_device_map = expand_device_map(device_map, meta_state_dict.keys())
  393. str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
  394. if sharded_metadata is None:
  395. weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0])
  396. else:
  397. folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
  398. weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()}
  399. # Update the weight names according to the `weight_mapping`
  400. weight_renaming_map = {
  401. rename_source_key(k, renamings, [], model.base_model_prefix, meta_state_dict)[0]: k for k in weight_map
  402. }
  403. # Prepare the index using existing safetensors files
  404. disk_offload_index = {
  405. target_name: {
  406. "safetensors_file": weight_map[source_name],
  407. "weight_name": source_name,
  408. "dtype": str_dtype,
  409. }
  410. for target_name, source_name in weight_renaming_map.items()
  411. # Need to check if it's in the mapping in case of unexpected keys that would result in KeyError (we skip them)
  412. if target_name in param_device_map and param_device_map[target_name] == "disk"
  413. }
  414. # In this case we will resave every offloaded weight
  415. else:
  416. disk_offload_index = {}
  417. return disk_offload_index
  418. def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str | None, offload_index: dict) -> dict:
  419. """Write `weight` to disk inside `offload_folder`, and update `offload_index` accordingly. Everything is
  420. saved in `safetensors` format."""
  421. if offload_folder is None:
  422. raise ValueError(
  423. "The current `device_map` had weights offloaded to the disk, which needed to be re-saved. This is either "
  424. "because the weights are not in `safetensors` format, or because the model uses an internal weight format "
  425. "different than the one saved (i.e. most MoE models). Please provide an `offload_folder` for them in "
  426. "`from_pretrained`."
  427. )
  428. # Write the weight to disk
  429. safetensor_file = os.path.join(offload_folder, f"{weight_name}.safetensors")
  430. save_file({weight_name: weight}, safetensor_file)
  431. # Update the offloading index
  432. str_dtype = str(weight.dtype).replace("torch.", "")
  433. offload_index[weight_name] = {"safetensors_file": safetensor_file, "weight_name": weight_name, "dtype": str_dtype}
  434. return offload_index
  435. def load_offloaded_parameter(model: "PreTrainedModel", param_name: str) -> torch.Tensor:
  436. """Load `param_name` from disk, if it was offloaded due to the device_map, and thus lives as a meta parameter
  437. inside `model`.
  438. This is needed when resaving a model, when some parameters were offloaded (we need to load them from disk, to
  439. then resave them to disk in the correct shard...)."""
  440. # Start from the most inner module, and try to find the hook that was used for offloading the param
  441. module_parts = param_name.split(".")
  442. modules_to_check = [".".join(module_parts[:-idx]) for idx in range(1, len(module_parts))] + [""]
  443. for parent_name in modules_to_check:
  444. parent = model.get_submodule(parent_name)
  445. if hasattr(parent, "_hf_hook"):
  446. weights_map = parent._hf_hook.weights_map
  447. truncated_param_name = param_name.replace(f"{parent_name}." if parent_name != "" else parent_name, "")
  448. break
  449. # If we did not break the loop, something is wrong
  450. else:
  451. raise ValueError(
  452. f"{param_name} is on the meta device because it was offloaded, but we could not find "
  453. "the corresponding hook for it"
  454. )
  455. # This call loads it from disk
  456. tensor = weights_map[truncated_param_name]
  457. return tensor
  458. def _init_infer_auto_device_map(
  459. model: nn.Module,
  460. max_memory: dict[int | str, int | str] | None = None,
  461. no_split_module_classes: set[str] | None = None,
  462. tied_parameters: list[list[str]] | None = None,
  463. hf_quantizer: "HfQuantizer | None" = None,
  464. ) -> tuple[
  465. list[int | str],
  466. dict[int | str, int | str],
  467. list[int | str],
  468. list[int],
  469. dict[str, int],
  470. list[list[str]],
  471. list[str],
  472. list[tuple[str, nn.Module]],
  473. ]:
  474. """
  475. Initialize variables required for computing the device map for model allocation.
  476. """
  477. max_memory = get_max_memory(max_memory)
  478. if no_split_module_classes is None:
  479. no_split_module_classes = []
  480. elif not isinstance(no_split_module_classes, (list, tuple, set)):
  481. no_split_module_classes = [no_split_module_classes]
  482. devices = list(max_memory.keys())
  483. if "disk" not in devices:
  484. devices.append("disk")
  485. gpus = [device for device in devices if device not in ["cpu", "disk"]]
  486. # Devices that need to keep space for a potential offloaded layer.
  487. if "mps" in gpus:
  488. main_devices = ["mps"]
  489. elif len(gpus) > 0:
  490. main_devices = [gpus[0], "cpu"]
  491. else:
  492. main_devices = ["cpu"]
  493. module_sizes, _ = compute_module_sizes(model, hf_quantizer, only_modules=False)
  494. if tied_parameters is None:
  495. if len(model.all_tied_weights_keys) > 0:
  496. # create a list of list of tied params based on unique tied groups
  497. groups = set(model.all_tied_weights_keys.values())
  498. tied_parameters = [
  499. sorted([k for k, v in model.all_tied_weights_keys.items() if v == target] + [target])
  500. for target in groups
  501. ]
  502. else:
  503. tied_parameters = [[]]
  504. # Direct submodules and parameters
  505. modules_to_treat = (
  506. list(model.named_parameters(recurse=False))
  507. + list(model.named_children())
  508. + list(model.named_buffers(recurse=False))
  509. )
  510. return (
  511. devices,
  512. max_memory,
  513. main_devices,
  514. gpus,
  515. module_sizes,
  516. tied_parameters,
  517. no_split_module_classes,
  518. modules_to_treat,
  519. )
  520. def infer_auto_device_map(
  521. model: nn.Module,
  522. max_memory: dict[int | str, int | str] | None = None,
  523. no_split_module_classes: set[str] | None = None,
  524. verbose: bool = False,
  525. clean_result: bool = True,
  526. offload_buffers: bool = False,
  527. tied_parameters: list[list[str]] | None = None,
  528. hf_quantizer: "HfQuantizer | None" = None,
  529. ):
  530. """
  531. Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
  532. such that:
  533. - we don't exceed the memory available of any of the GPU.
  534. - if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that
  535. has the largest size.
  536. - if offload to the CPU is needed,we don't exceed the RAM available on the CPU.
  537. - if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk
  538. that has the largest size.
  539. <Tip>
  540. All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
  541. meta device (as it would if initialized within the `init_empty_weights` context manager).
  542. </Tip>
  543. Args:
  544. model (`torch.nn.Module`):
  545. The model to analyze.
  546. max_memory (`Dict`, *optional*):
  547. A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
  548. Example: `max_memory={0: "1GB"}`.
  549. no_split_module_classes (`set[str]`, *optional*):
  550. A set of layer class names that should never be split across device (for instance any layer that has a
  551. residual connection).
  552. verbose (`bool`, *optional*, defaults to `False`):
  553. Whether or not to provide debugging statements as the function builds the device_map.
  554. clean_result (`bool`, *optional*, defaults to `True`):
  555. Clean the resulting device_map by grouping all submodules that go on the same device together.
  556. offload_buffers (`bool`, *optional*, defaults to `False`):
  557. In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
  558. well as the parameters.
  559. """
  560. # Initialize the variables
  561. (
  562. devices,
  563. max_memory,
  564. main_devices,
  565. gpus,
  566. module_sizes,
  567. tied_parameters,
  568. no_split_module_classes,
  569. modules_to_treat,
  570. ) = _init_infer_auto_device_map(model, max_memory, no_split_module_classes, tied_parameters, hf_quantizer)
  571. device_map = OrderedDict()
  572. current_device = 0
  573. device_memory_used = dict.fromkeys(devices, 0)
  574. device_buffer_sizes = {}
  575. device_minimum_assignment_memory = {}
  576. # Initialize maximum largest layer, to know which space to keep in memory
  577. max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)
  578. # Ready ? This is going to be a bit messy.
  579. while len(modules_to_treat) > 0:
  580. name, module = modules_to_treat.pop(0)
  581. if verbose:
  582. print(f"\nTreating module {name}.")
  583. # Max size in the remaining layers may have changed since we took one, so we maybe update it.
  584. max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")]
  585. if len(max_layer_names) == 0:
  586. max_layer_size, max_layer_names = get_max_layer_size(
  587. [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
  588. module_sizes,
  589. no_split_module_classes,
  590. )
  591. # Assess size needed
  592. module_size = module_sizes[name]
  593. # We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module
  594. # and the other is not.
  595. # Note: If we are currently processing the name `compute.weight`, an other parameter named
  596. # e.g. `compute.weight_submodule.parameter`
  597. # needs to be considered outside the current module, hence the check with additional dots.
  598. tied_param_groups = [
  599. tied_group
  600. for tied_group in tied_parameters
  601. if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group)
  602. ]
  603. if verbose and len(tied_param_groups) > 0:
  604. print(f" Found the relevant tied param groups {tied_param_groups}")
  605. # Then we keep track of all the parameters that are tied to the current module, but not in the current module
  606. tied_params = sum(
  607. [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], []
  608. )
  609. if verbose and len(tied_params) > 0:
  610. print(f" So those parameters need to be taken into account {tied_params}")
  611. device = devices[current_device]
  612. current_max_size = max_memory[device] if device != "disk" else None
  613. current_memory_reserved = 0
  614. # Reduce max size available by the largest layer.
  615. if devices[current_device] in main_devices:
  616. current_max_size = current_max_size - max_layer_size
  617. current_memory_reserved = max_layer_size
  618. module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(
  619. tied_params, module_size, module_sizes, modules_to_treat
  620. )
  621. # The module and its tied modules fit on the current device.
  622. if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size:
  623. if verbose:
  624. output = f"Putting {name}"
  625. if tied_module_names:
  626. output += f" and {tied_module_names}"
  627. else:
  628. output += f" (size={module_size})"
  629. if current_max_size is not None:
  630. output += f" (available={current_max_size - device_memory_used[device]})"
  631. output += f" on {device}."
  632. print(output)
  633. device_memory_used[device] += module_size_with_ties
  634. # Assign the primary module to the device.
  635. device_map[name] = device
  636. # Assign tied modules if any.
  637. for tied_module_name in tied_module_names:
  638. if tied_module_name in [m[0] for m in modules_to_treat]:
  639. # Find the index of the tied module in the list
  640. tied_module_index = next(i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name)
  641. # Remove the tied module from the list to prevent reprocessing
  642. modules_to_treat.pop(tied_module_index)
  643. # Assign the tied module to the device
  644. device_map[tied_module_name] = device
  645. # Buffer Handling
  646. if not offload_buffers and isinstance(module, nn.Module):
  647. # Compute the total buffer size for the module
  648. current_buffer_size = compute_module_total_buffer_size(module, hf_quantizer)
  649. # Update the buffer size on the device
  650. device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size
  651. continue
  652. # The current module itself fits, so we try to split the tied modules.
  653. if len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size:
  654. # can we split one of the tied modules to make it smaller or do we need to go on the next device?
  655. if verbose:
  656. print(
  657. f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
  658. f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})."
  659. )
  660. split_happened = False
  661. for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
  662. tied_module_children = list(tied_module.named_children())
  663. if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
  664. # can't break this one.
  665. continue
  666. if verbose:
  667. print(f"Splitting {tied_module_name}.")
  668. tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
  669. tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
  670. tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]
  671. modules_to_treat = (
  672. [(name, module)]
  673. + modules_to_treat[:tied_module_index]
  674. + tied_module_children
  675. + modules_to_treat[tied_module_index + 1 :]
  676. )
  677. # Update the max layer size.
  678. max_layer_size, max_layer_names = get_max_layer_size(
  679. [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
  680. module_sizes,
  681. no_split_module_classes,
  682. )
  683. split_happened = True
  684. break
  685. if split_happened:
  686. continue
  687. # If the tied module is not split, we go to the next device
  688. if verbose:
  689. print("None of the tied module can be split, going to the next device.")
  690. # The current module itself doesn't fit, so we have to split it or go to the next device.
  691. if device_memory_used[device] + module_size >= current_max_size:
  692. # Split or not split?
  693. modules_children = (
  694. []
  695. if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)
  696. else list(module.named_children())
  697. )
  698. if verbose:
  699. print(
  700. f"Not enough space on {devices[current_device]} to put {name} (space available "
  701. f"{current_max_size - device_memory_used[device]}, module size {module_size})."
  702. )
  703. if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
  704. # -> no split, we go to the next device
  705. if verbose:
  706. print("This module cannot be split, going to the next device.")
  707. else:
  708. # -> split, we replace the module studied by its children + parameters
  709. if verbose:
  710. print(f"Splitting {name}.")
  711. modules_children = list(module.named_parameters(recurse=False)) + modules_children
  712. modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat
  713. # Update the max layer size.
  714. max_layer_size, max_layer_names = get_max_layer_size(
  715. [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
  716. module_sizes,
  717. no_split_module_classes,
  718. )
  719. continue
  720. if device_memory_used[device] == 0:
  721. device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved
  722. # Neither the current module nor any tied modules can be split, so we move to the next device.
  723. device_memory_used[device] = device_memory_used[device] + current_memory_reserved
  724. current_device += 1
  725. modules_to_treat = [(name, module)] + modules_to_treat
  726. device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}
  727. if clean_result:
  728. device_map = clean_device_map(device_map)
  729. non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0)
  730. if non_gpu_buffer_size > 0 and not offload_buffers:
  731. is_buffer_fit_any_gpu = False
  732. for gpu_device, gpu_max_memory in max_memory.items():
  733. if gpu_device == "cpu" or gpu_device == "disk":
  734. continue
  735. if not is_buffer_fit_any_gpu:
  736. gpu_memory_used = device_memory_used.get(gpu_device, 0)
  737. if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used:
  738. is_buffer_fit_any_gpu = True
  739. if len(gpus) > 0 and not is_buffer_fit_any_gpu:
  740. logger.warning(
  741. f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does "
  742. f"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using "
  743. f"offload_buffers=True."
  744. )
  745. if device_minimum_assignment_memory:
  746. devices_info = "\n".join(
  747. f" - {device}: {mem} bytes required" for device, mem in device_minimum_assignment_memory.items()
  748. )
  749. logger.info(
  750. f"Based on the current allocation process, no modules could be assigned to the following devices due to "
  751. f"insufficient memory:\n"
  752. f"{devices_info}\n"
  753. f"These minimum requirements are specific to this allocation attempt and may vary. Consider increasing "
  754. f"the available memory for these devices to at least the specified minimum, or adjusting the model config."
  755. )
  756. check_tied_parameters_on_same_device(tied_parameters, device_map)
  757. return device_map
  758. def _get_param_device(param, device_map):
  759. if param in device_map:
  760. return device_map[param]
  761. parent_param = ".".join(param.split(".")[:-1])
  762. if parent_param == param:
  763. raise ValueError(f"The `device_map` does not contain the module {param}.")
  764. else:
  765. return _get_param_device(parent_param, device_map)
  766. def check_tied_parameters_on_same_device(tied_params, device_map):
  767. """
  768. Check if tied parameters are on the same device
  769. Args:
  770. tied_params (`List[List[str]]`):
  771. A list of lists of parameter names being all tied together.
  772. device_map (`Dict[str, Union[int, str, torch.device]]`):
  773. A map that specifies where each submodule should go.
  774. """
  775. for tie_param in tied_params:
  776. tie_param_devices = {}
  777. for param in tie_param:
  778. tie_param_devices[param] = _get_param_device(param, device_map)
  779. if len(set(tie_param_devices.values())) > 1:
  780. logger.warning(
  781. f"Tied parameters are on different devices: {tie_param_devices}. "
  782. "Please modify your custom device map or set `device_map='auto'`. "
  783. )