deepspeed.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Integration with Deepspeed
  16. """
  17. import copy
  18. import importlib.metadata
  19. import importlib.util
  20. import weakref
  21. from functools import partialmethod
  22. from ..dependency_versions_check import dep_version_check
  23. from ..utils import is_accelerate_available, is_torch_available, logging
  24. if is_torch_available():
  25. import torch
  26. from torch import nn
  27. logger = logging.get_logger(__name__)
  28. def is_deepspeed_available():
  29. package_exists = importlib.util.find_spec("deepspeed") is not None
  30. # Check we're not importing a "deepspeed" directory somewhere but the actual library by trying to grab the version
  31. # AND checking it has an author field in the metadata that is HuggingFace.
  32. if package_exists:
  33. try:
  34. _ = importlib.metadata.metadata("deepspeed")
  35. return True
  36. except importlib.metadata.PackageNotFoundError:
  37. return False
  38. if is_accelerate_available() and is_deepspeed_available():
  39. from accelerate.utils.deepspeed import HfDeepSpeedConfig as DeepSpeedConfig
  40. else:
  41. # Inherits from a dummy `object` if accelerate is not available, so that python succeeds to import this file.
  42. # Deepspeed glue code will never inherit this dummy object as it checks if accelerate is available.
  43. from builtins import object as DeepSpeedConfig
  44. class HfDeepSpeedConfig(DeepSpeedConfig): # noqa UP004
  45. """
  46. This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.
  47. A `weakref` of this object is stored in the module's globals to be able to access the config from areas where
  48. things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore
  49. it's important that this object remains alive while the program is still running.
  50. [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration
  51. with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic
  52. the DeepSpeed configuration is not modified in any way.
  53. Args:
  54. config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict.
  55. """
  56. def __init__(self, config_file_or_dict):
  57. # set global weakref object
  58. set_hf_deepspeed_config(self)
  59. dep_version_check("accelerate")
  60. dep_version_check("deepspeed")
  61. super().__init__(config_file_or_dict)
  62. class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
  63. """
  64. The `HfTrainerDeepSpeedConfig` object is meant to be created during `TrainingArguments` object creation and has the
  65. same lifespan as the latter.
  66. """
  67. def __init__(self, config_file_or_dict):
  68. super().__init__(config_file_or_dict)
  69. self._dtype = None
  70. self.mismatches = []
  71. def dtype(self):
  72. if self._dtype is None:
  73. raise ValueError("trainer_config_process() wasn't called yet to tell dtype")
  74. return self._dtype
  75. def is_auto(self, ds_key_long):
  76. val = self.get_value(ds_key_long)
  77. if val is None:
  78. return False
  79. else:
  80. return val == "auto"
  81. def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
  82. """
  83. A utility method that massages the config file and can optionally verify that the values match.
  84. 1. Replace "auto" values with `TrainingArguments` value.
  85. 2. If it wasn't "auto" and `must_match` is true, then check that DS config matches Trainer
  86. config values and if mismatched add the entry to `self.mismatched` - will assert during
  87. `trainer_config_finalize` for one or more mismatches.
  88. """
  89. config, ds_key = self.find_config_node(ds_key_long)
  90. if config is None:
  91. return
  92. if config.get(ds_key) == "auto":
  93. config[ds_key] = hf_val
  94. return
  95. if not must_match:
  96. return
  97. ds_val = config.get(ds_key)
  98. if ds_val is not None and ds_val != hf_val:
  99. self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}")
  100. fill_only = partialmethod(fill_match, must_match=False)
  101. def trainer_config_process(self, args, auto_find_batch_size=False):
  102. """
  103. Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object
  104. creation.
  105. """
  106. # DeepSpeed does:
  107. # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
  108. train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
  109. self.fill_match(
  110. "train_micro_batch_size_per_gpu",
  111. args.per_device_train_batch_size,
  112. "per_device_train_batch_size",
  113. not auto_find_batch_size,
  114. )
  115. self.fill_match(
  116. "gradient_accumulation_steps",
  117. args.gradient_accumulation_steps,
  118. "gradient_accumulation_steps",
  119. )
  120. self.fill_match(
  121. "train_batch_size",
  122. train_batch_size,
  123. "train_batch_size (calculated)",
  124. not auto_find_batch_size,
  125. )
  126. self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm")
  127. self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate")
  128. self.fill_match(
  129. "optimizer.params.betas",
  130. [args.adam_beta1, args.adam_beta2],
  131. "adam_beta1+adam_beta2",
  132. )
  133. self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon")
  134. self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay")
  135. self.fill_only("scheduler.params.warmup_min_lr", 0) # not a trainer arg
  136. self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate")
  137. # total_num_steps - will get set in trainer_config_finalize
  138. if args.save_on_each_node:
  139. # deepspeed uses shared storage by default. Let's override this setting if save_on_each_node == True
  140. self.config["checkpoint"] = self.config.get("checkpoint", {})
  141. self.config["checkpoint"]["use_node_local_storage"] = args.save_on_each_node
  142. # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set
  143. # any here unless the user did the work
  144. self.fill_match("fp16.enabled", (args.fp16 or args.fp16_full_eval), "fp16|fp16_full_eval")
  145. self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval")
  146. # deepspeed's default mode is fp16 unless there is a config that says differently
  147. if self.is_true("bf16.enabled"):
  148. self._dtype = torch.bfloat16
  149. elif self.is_true("fp16.enabled"):
  150. self._dtype = torch.float16
  151. else:
  152. self._dtype = torch.float32
  153. def trainer_config_finalize(self, args, model, num_training_steps):
  154. """
  155. This stage is run after we have the model and know num_training_steps.
  156. Now we can complete the configuration process.
  157. """
  158. # zero
  159. # deal with config keys that use `auto` value and rely on model's hidden_size
  160. hidden_size_based_keys = [
  161. "zero_optimization.reduce_bucket_size",
  162. "zero_optimization.stage3_prefetch_bucket_size",
  163. "zero_optimization.stage3_param_persistence_threshold",
  164. ]
  165. hidden_size_auto_keys = [x for x in hidden_size_based_keys if self.is_auto(x)]
  166. if len(hidden_size_auto_keys) > 0:
  167. hidden_size = None
  168. if hasattr(model, "config"):
  169. if hasattr(model.config, "hidden_size"):
  170. hidden_size = model.config.hidden_size
  171. elif hasattr(model.config, "hidden_sizes"):
  172. # if there are many hidden sizes pick the largest one
  173. hidden_size = max(model.config.hidden_sizes)
  174. elif hasattr(model.config, "text_config") and hasattr(model.config.text_config, "hidden_size"):
  175. hidden_size = model.config.text_config.hidden_size
  176. elif hasattr(model.config, "text_config") and hasattr(model.config.text_config, "hidden_sizes"):
  177. # if there are many hidden sizes pick the largest one
  178. hidden_size = max(model.config.text_config.hidden_sizes)
  179. if hidden_size is None:
  180. raise ValueError(
  181. "The model's config file has neither `hidden_size` nor `hidden_sizes` entry, "
  182. "therefore it's not possible to automatically fill out the following `auto` entries "
  183. f"in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing "
  184. "`auto` values for these keys with an integer value of your choice."
  185. )
  186. self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
  187. if self.is_zero3():
  188. # automatically assign the optimal config values based on model config
  189. self.fill_only(
  190. "zero_optimization.stage3_prefetch_bucket_size",
  191. int(0.9 * hidden_size * hidden_size),
  192. )
  193. self.fill_only(
  194. "zero_optimization.stage3_param_persistence_threshold",
  195. 10 * hidden_size,
  196. )
  197. # scheduler
  198. self.fill_match(
  199. "scheduler.params.total_num_steps",
  200. num_training_steps,
  201. "num_training_steps (calculated)",
  202. )
  203. self.fill_match(
  204. "scheduler.params.warmup_num_steps",
  205. args.get_warmup_steps(num_training_steps),
  206. "warmup_steps",
  207. )
  208. if len(self.mismatches) > 0:
  209. mismatches = "\n".join(self.mismatches)
  210. raise ValueError(
  211. "Please correct the following DeepSpeed config values that mismatch TrainingArguments"
  212. f" values:\n{mismatches}\nThe easiest method is to set these DeepSpeed config values to 'auto'."
  213. )
  214. # keep the config object global to be able to access it anywhere during TrainingArguments life-cycle
  215. _hf_deepspeed_config_weak_ref = None
  216. def set_hf_deepspeed_config(hf_deepspeed_config_obj):
  217. # this is a special weakref global object to allow us to get to Deepspeed config from APIs
  218. # that don't have an easy way to get to the Deepspeed config outside of the Trainer domain.
  219. global _hf_deepspeed_config_weak_ref
  220. # will go away automatically when HfDeepSpeedConfig is destroyed (when TrainingArguments is destroyed)
  221. _hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj)
  222. def unset_hf_deepspeed_config():
  223. # useful for unit tests to ensure the global state doesn't leak - call from `tearDown` method
  224. global _hf_deepspeed_config_weak_ref
  225. _hf_deepspeed_config_weak_ref = None
  226. def is_deepspeed_zero3_enabled():
  227. if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
  228. return _hf_deepspeed_config_weak_ref().is_zero3()
  229. else:
  230. return False
  231. def deepspeed_config():
  232. if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
  233. return _hf_deepspeed_config_weak_ref().config
  234. else:
  235. return None
  236. def initialize_weights_zero3(model):
  237. """
  238. DeepSpeed ZeRO-3 variant of `PreTrainedModel.initialize_weights`. Mirrors the `smart_apply`
  239. dispatch logic but gathers each module's partitioned parameters before calling
  240. `_initialize_weights`, so initialization operates on full tensors instead of empty shards.
  241. Only rank 0 performs the actual init.
  242. """
  243. import deepspeed
  244. import torch
  245. from ..initialization import guard_torch_init_functions
  246. from ..modeling_utils import PreTrainedModel
  247. is_remote_code = model.is_remote_code()
  248. def _apply_zero3(model_or_module, fn):
  249. for child in model_or_module.children():
  250. if isinstance(child, PreTrainedModel):
  251. _apply_zero3(child, child._initialize_weights)
  252. else:
  253. _apply_zero3(child, fn)
  254. params = list(model_or_module.parameters(recurse=False))
  255. if params:
  256. with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
  257. if deepspeed.comm.get_rank() == 0:
  258. fn(model_or_module, is_remote_code)
  259. else:
  260. fn(model_or_module, is_remote_code)
  261. with torch.no_grad():
  262. with guard_torch_init_functions():
  263. _apply_zero3(model, model._initialize_weights)
  264. def _apply_weight_conversions_to_state_dict(model, state_dict, weight_mapping):
  265. """
  266. Apply weight conversions (renaming and merging/splitting operations) to a state dict.
  267. This is a simplified version that handles the conversion without loading into the model.
  268. """
  269. # Check for Tensor Parallelism - weight conversions are not tested with TP
  270. # TP uses ReplaceWithTensorSlicing which may conflict with our weight conversions
  271. ds_config = deepspeed_config()
  272. if ds_config is not None:
  273. # Check training config (tensor_parallel.autotp_size)
  274. tp_size = ds_config.get("tensor_parallel", {}).get("autotp_size", 1)
  275. # Check inference config (inference.tensor_parallel.tp_size)
  276. inference_config = ds_config.get("inference", {})
  277. if isinstance(inference_config, dict):
  278. tp_size = max(tp_size, inference_config.get("tensor_parallel", {}).get("tp_size", 1))
  279. if tp_size > 1:
  280. raise NotImplementedError(
  281. "Weight conversions (e.g., MoE expert fusion) with DeepSpeed Tensor Parallelism "
  282. "are not yet implemented but support is coming soon. Please disable tensor_parallel "
  283. "in your DeepSpeed config or convert your checkpoint to the expected format first."
  284. )
  285. from ..core_model_loading import WeightConverter, WeightRenaming, dot_natural_key, rename_source_key
  286. # Preserve metadata from the original state dict
  287. metadata = getattr(state_dict, "_metadata", None)
  288. prefix = model.base_model_prefix
  289. # Build a meta state dict for matching - only keys/shapes, no actual tensor data
  290. # This minimizes memory since we don't duplicate the model's parameters
  291. model_state_dict = {}
  292. for key, param in model.state_dict().items():
  293. model_state_dict[key] = torch.empty(param.shape, dtype=param.dtype, device="meta")
  294. renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
  295. converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
  296. # Fast path: if we only have simple renamings and no converters, we can skip the expensive collection logic
  297. if len(converters) == 0:
  298. new_state_dict = {}
  299. for original_key, tensor in state_dict.items():
  300. renamed_key, _ = rename_source_key(original_key, renamings, [], prefix, model_state_dict)
  301. if renamed_key in model_state_dict:
  302. new_state_dict[renamed_key] = tensor
  303. # Attach metadata to the new state dict
  304. if metadata is not None:
  305. new_state_dict._metadata = metadata
  306. return new_state_dict
  307. # Full path: we have WeightConverter operations that require tensor fusion/splitting
  308. pattern_to_converter = {k: converter for converter in converters for k in converter.source_patterns}
  309. # Build a mapping of what needs to be converted
  310. # Sort keys to ensure consistent ordering (important for MoE conversions)
  311. # Iterate over sorted keys and pop from state_dict to free memory immediately
  312. conversion_mapping = {}
  313. new_state_dict = {}
  314. sorted_keys = sorted(state_dict.keys(), key=lambda k: dot_natural_key(k))
  315. for original_key in sorted_keys:
  316. tensor = state_dict.pop(original_key)
  317. renamed_key, source_pattern = rename_source_key(original_key, renamings, converters, prefix, model_state_dict)
  318. # Only process if the renamed key is in the model's state dict
  319. if renamed_key in model_state_dict:
  320. # If source_pattern is not None, this key needs WeightConverter (e.g., MoE fusion)
  321. if source_pattern is not None:
  322. # Create a fresh converter for this layer to hold its tensors
  323. # Share operations list (lightweight, no large data) but get new collected_tensors
  324. converter = pattern_to_converter[source_pattern]
  325. new_converter = WeightConverter(
  326. source_patterns=converter.source_patterns,
  327. target_patterns=converter.target_patterns,
  328. operations=converter.operations,
  329. )
  330. mapping = conversion_mapping.setdefault(renamed_key, new_converter)
  331. mapping.add_tensor(renamed_key, original_key, source_pattern, tensor)
  332. else:
  333. # No conversion needed - add tensor directly to new_state_dict
  334. # (this handles keys like embed_tokens, lm_head, layernorm, attention)
  335. new_state_dict[renamed_key] = tensor
  336. # Apply the conversions and build the new state dict
  337. for renamed_key, mapping in conversion_mapping.items():
  338. try:
  339. realized_value = mapping.convert(
  340. renamed_key,
  341. model=model,
  342. config=model.config,
  343. )
  344. for target_name, param in realized_value.items():
  345. param = param[0] if isinstance(param, list) else param
  346. new_state_dict[target_name] = param
  347. except Exception as e:
  348. raise RuntimeError(
  349. f"Failed to apply weight conversion for '{renamed_key}'. "
  350. f"This likely means the checkpoint format is incompatible with the current model version. "
  351. f"Error: {e}"
  352. ) from e
  353. # Attach metadata to the new state dict
  354. if metadata is not None:
  355. new_state_dict._metadata = metadata
  356. return new_state_dict
  357. def _load_state_dict_into_zero3_model(model_to_load, state_dict, load_config=None):
  358. """
  359. Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
  360. tensor parallelism API.
  361. Nearly identical code to PyTorch's `_load_from_state_dict`
  362. Args:
  363. model_to_load: The model to load weights into
  364. state_dict: The state dict containing the weights
  365. load_config: Optional LoadStateDictConfig containing weight_mapping and other loading options
  366. """
  367. # copy state_dict so `_load_state_dict_into_zero3_model` can modify it
  368. metadata = getattr(state_dict, "_metadata", None)
  369. state_dict = state_dict.copy()
  370. if metadata is not None:
  371. state_dict._metadata = metadata
  372. # Extract weight_mapping from load_config if provided
  373. weight_mapping = None
  374. if load_config is not None:
  375. weight_mapping = getattr(load_config, "weight_mapping", None)
  376. # Apply weight conversions if provided
  377. if weight_mapping is not None and len(weight_mapping) > 0:
  378. state_dict = _apply_weight_conversions_to_state_dict(model_to_load, state_dict, weight_mapping)
  379. # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user)
  380. model_to_load._weight_conversions = weight_mapping
  381. error_msgs = []
  382. meta_model_state_dict = model_to_load.state_dict()
  383. missing_keys = set(meta_model_state_dict.keys())
  384. prefix_model = getattr(model_to_load, "base_model_prefix", None)
  385. # take care of the case where in the checkpoint we don't have the prefix
  386. state_dict = {
  387. (f"{prefix_model}.{k}" if meta_model_state_dict.get(f"{prefix_model}.{k}") is not None else k): v
  388. for k, v in state_dict.items()
  389. }
  390. # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
  391. # so we need to apply the function recursively.
  392. def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
  393. local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
  394. local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
  395. args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
  396. # Parameters of module and children will start with prefix. We can exit early if there are none in this
  397. # state_dict
  398. if is_deepspeed_zero3_enabled():
  399. import deepspeed
  400. # In sharded models, each shard has only part of the full state_dict, so only gather
  401. # parameters that are in the current state_dict.
  402. named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
  403. params_to_gather = []
  404. for k in named_parameters:
  405. if k in state_dict:
  406. param = named_parameters[k]
  407. # crutial to not init the weight again
  408. param._is_hf_initialized = True
  409. params_to_gather.append(param)
  410. missing_keys.discard(k)
  411. if len(params_to_gather) > 0:
  412. # because zero3 puts placeholders in model params, this context
  413. # manager gathers (unpartitions) the params of the current layer, then loads from
  414. # the state dict and then re-partitions them again
  415. with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
  416. if torch.distributed.get_rank() == 0:
  417. module._load_from_state_dict(*args)
  418. for name, child in module._modules.items():
  419. if child is not None:
  420. load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
  421. load(model_to_load, state_dict, assign_to_params_buffers=False)
  422. return error_msgs, missing_keys
  423. def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
  424. """
  425. A convenience wrapper that deals with optimizer and lr scheduler configuration.
  426. """
  427. from accelerate.utils import DummyOptim, DummyScheduler
  428. config = hf_deepspeed_config.config
  429. # Mixing and matching DS schedulers and optimizers is supported unless Offload is enabled in which case it's:
  430. # 1. DS scheduler + DS optimizer: Yes
  431. # 2. HF scheduler + HF optimizer: Mostly*
  432. # 3. DS scheduler + HF optimizer: Mostly*
  433. # 4. HF scheduler + DS optimizer: Yes
  434. #
  435. # Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)
  436. optimizer = None
  437. if "optimizer" in config:
  438. optimizer = DummyOptim(params=model_parameters)
  439. else:
  440. if hf_deepspeed_config.is_offload():
  441. logger.info(
  442. "Detected ZeRO Offload and non-DeepSpeed optimizers: This combination should work as long as the"
  443. " custom optimizer has both CPU and GPU implementation (except LAMB)"
  444. )
  445. # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
  446. # But trainer uses AdamW by default.
  447. optimizer = trainer.create_optimizer()
  448. # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer`
  449. config["zero_allow_untested_optimizer"] = True
  450. lr_scheduler = None
  451. if "scheduler" in config:
  452. lr_scheduler = DummyScheduler(optimizer)
  453. else:
  454. if isinstance(optimizer, DummyOptim):
  455. def _lr_scheduler_callable(optimizer):
  456. # create a shallow copy first, so later modifications do not affect original trainer
  457. trainer_copy = copy.copy(trainer)
  458. # at the time _lr_scheduler_callable is called, trainer.lr_scheduler has been set
  459. # update it to None so that we can re-create a new scheduler
  460. trainer_copy.lr_scheduler = None
  461. lr_scheduler = trainer_copy.create_scheduler(
  462. num_training_steps=num_training_steps, optimizer=optimizer
  463. )
  464. return lr_scheduler
  465. lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
  466. return optimizer, lr_scheduler
  467. def deepspeed_init(trainer, num_training_steps, inference=False):
  468. """
  469. Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
  470. If `resume_from_checkpoint` was passed then an attempt to resume from a previously saved checkpoint will be made.
  471. Args:
  472. trainer: Trainer object
  473. num_training_steps: per single gpu
  474. resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load
  475. inference: launch in inference mode (no optimizer and no lr scheduler)
  476. auto_find_batch_size: whether to ignore the `train_micro_batch_size_per_gpu` argument as it's being
  477. set automatically by the auto batch size finder
  478. Returns: optimizer, lr_scheduler
  479. We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on:
  480. https://github.com/deepspeedai/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it
  481. can't resume from a checkpoint after it did some stepping https://github.com/deepspeedai/DeepSpeed/issues/1612
  482. """
  483. from deepspeed.utils import logger as ds_logger
  484. model = trainer.model
  485. args = trainer.args
  486. hf_deepspeed_config = trainer.accelerator.state.deepspeed_plugin.hf_ds_config
  487. # resume config update - some bits like `model` and `num_training_steps` only become available during train
  488. hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)
  489. # set the Deepspeed log level consistent with the Trainer
  490. ds_logger.setLevel(args.get_process_log_level())
  491. if inference:
  492. # only Z3 makes sense for the inference
  493. if not hf_deepspeed_config.is_zero3():
  494. raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config")
  495. # in case the training config is re-used for inference
  496. hf_deepspeed_config.del_config_sub_tree("optimizer")
  497. hf_deepspeed_config.del_config_sub_tree("lr_scheduler")
  498. optimizer, lr_scheduler = None, None
  499. model_parameters = None
  500. else:
  501. trainer.optimizer = None # important for when deepspeed_init is used as re-init
  502. deepspeed_tp_size = hf_deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 1)
  503. if deepspeed_tp_size > 1:
  504. import deepspeed
  505. model = deepspeed.tp_model_init(
  506. model=model,
  507. tp_size=deepspeed_tp_size,
  508. dtype=hf_deepspeed_config.dtype(),
  509. config=hf_deepspeed_config.config,
  510. )
  511. model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
  512. optimizer, lr_scheduler = deepspeed_optim_sched(
  513. trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
  514. )
  515. # keep for quick debug:
  516. # from pprint import pprint; pprint(config)
  517. return optimizer, lr_scheduler
  518. def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):
  519. # it's possible that the user is trying to resume from model_path, which doesn't necessarily
  520. # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
  521. # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
  522. # path contains what looks like a deepspeed checkpoint
  523. import glob
  524. deepspeed_checkpoint_dirs = sorted(glob.glob(f"{checkpoint_path}/global_step*"))
  525. if len(deepspeed_checkpoint_dirs) > 0:
  526. logger.info(f"Attempting to resume from {checkpoint_path}")
  527. # this magically updates self.optimizer and self.lr_scheduler
  528. load_path, _ = deepspeed_engine.load_checkpoint(
  529. checkpoint_path,
  530. load_module_strict=load_module_strict,
  531. load_optimizer_states=True,
  532. load_lr_scheduler_states=True,
  533. )
  534. if load_path is None:
  535. raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
  536. else:
  537. raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")
  538. def propagate_args_to_deepspeed(accelerator, args, auto_find_batch_size=False):
  539. """
  540. Sets values in the deepspeed plugin based on the TrainingArguments.
  541. Args:
  542. accelerator (`Accelerator`): The Accelerator object.
  543. args (`TrainingArguments`): The training arguments to propagate to DeepSpeed config.
  544. auto_find_batch_size (`bool`, *optional*, defaults to `False`):
  545. Whether batch size was auto-discovered by trying increasingly smaller sizes.
  546. """
  547. ds_plugin = accelerator.state.deepspeed_plugin
  548. ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
  549. ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
  550. ds_plugin.hf_ds_config.trainer_config_process(args, auto_find_batch_size)
  551. def deepspeed_sp_compute_loss(accelerator, model, inputs, return_outputs, pc):
  552. """
  553. Computes the loss under sequence parallelism with `sp_backend="deepspeed"` and `sp_size > 1`.
  554. Performs weighted loss aggregation across SP ranks, accounting for varying numbers of valid tokens per rank
  555. (e.g., when some ranks receive only padding or prompt tokens that are masked with -100).
  556. Args:
  557. accelerator (`Accelerator`): The accelerator instance with `torch_device_mesh` support.
  558. model (`torch.nn.Module`): The model to compute the loss for.
  559. inputs (`dict[str, torch.Tensor | Any]`): The input data for the model. Must include `"shift_labels"` key.
  560. return_outputs (`bool`): Whether to return the model outputs along with the loss.
  561. pc (`accelerate.parallelism_config.ParallelismConfig`): The parallelism configuration.
  562. Returns:
  563. The loss, or a tuple of `(loss, outputs)` if `return_outputs` is `True`.
  564. """
  565. # DeepSpeed SP automatically injects shift_labels into inputs (pre-shifted labels for SP).
  566. # The model's forward pass receives shift_labels via **kwargs and passes it to the loss function.
  567. # Both standard transformer models and Liger-patched models handle shift_labels correctly,
  568. # so we can directly use the computed loss from the model output.
  569. # See: https://huggingface.co/docs/accelerate/en/concept_guides/sequence_parallelism
  570. if "labels" not in inputs and "shift_labels" in inputs:
  571. # DeepSpeed SP Dataloader removes "labels" but we need it, otherwise, we won't compute the loss.
  572. inputs["labels"] = inputs["shift_labels"]
  573. outputs = model(**inputs)
  574. loss = outputs.loss
  575. # Prefer DeepSpeed SP groups when using Ulysses; otherwise fall back to torch device mesh.
  576. if pc.sp_backend == "deepspeed" and pc.sp_size > 1:
  577. from deepspeed.utils import groups
  578. sp_group = groups._get_sequence_parallel_group()
  579. elif accelerator.torch_device_mesh is not None:
  580. sp_group = accelerator.torch_device_mesh["sp"].get_group()
  581. else:
  582. raise ValueError(
  583. "Sequence parallelism is enabled but no SP process group is available. "
  584. "Ensure torch_device_mesh is initialized or sp_backend='deepspeed' with sp_size > 1."
  585. )
  586. sp_world_size = pc.sp_size
  587. # differentiable weighted per-shard-loss aggregation across ranks
  588. losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
  589. # special dealing with SFT that has prompt tokens that aren't used in loss computation
  590. good_tokens = (inputs["shift_labels"] != -100).view(-1).sum()
  591. good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
  592. # Skip ranks with zero valid tokens
  593. total_loss = sum(
  594. losses_per_rank[rank] * good_tokens_per_rank[rank]
  595. for rank in range(sp_world_size)
  596. if good_tokens_per_rank[rank] > 0
  597. )
  598. total_good_tokens = sum(good_tokens_per_rank)
  599. loss = total_loss / max(total_good_tokens, 1)
  600. return (loss, outputs) if return_outputs else loss