tpu.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import functools
  15. import os
  16. import torch
  17. from torch.utils.data import DataLoader
  18. from ..utils import WEIGHTS_NAME, PushToHubMixin, is_torch_xla_available, logging
  19. logger = logging.get_logger(__name__)
  20. def tpu_spmd_dataloader(dataloader: DataLoader):
  21. if is_torch_xla_available():
  22. import torch_xla.distributed.parallel_loader as pl
  23. assert isinstance(dataloader, pl.MpDeviceLoader), (
  24. "The dataloader must be a `torch_xla.distributed.parallel_loader.MpDeviceLoader`."
  25. )
  26. # This is to support PyTorch/XLA FSDP via SPMD.
  27. # Here we shard the input data's 0th dim across the fsdp axis.
  28. import torch_xla.distributed.spmd as xs
  29. sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None))
  30. dataloader._parallel_loader_kwargs["input_sharding"] = sharding_spec
  31. return dataloader
  32. else:
  33. return dataloader
  34. def wrap_model_xla_fsdp(model, args, is_fsdp_xla_v2_enabled):
  35. """
  36. Wraps a model with XLA Fully Sharded Data Parallelism (FSDP).
  37. Handles both FSDP v1 (`XlaFullyShardedDataParallel`) and v2 (`SpmdFullyShardedDataParallel`),
  38. including auto-wrap policies, gradient checkpointing, and patching `xm.optimizer_step`.
  39. Args:
  40. model (`torch.nn.Module`): The model to wrap.
  41. args (`TrainingArguments`): The training arguments containing FSDP configuration.
  42. is_fsdp_xla_v2_enabled (`bool`): Whether FSDP v2 (SPMD) is enabled.
  43. Returns:
  44. `torch.nn.Module`: The FSDP-wrapped model.
  45. """
  46. import torch_xla.core.xla_model as xm
  47. import torch_xla.distributed.spmd as xs
  48. from ..trainer_pt_utils import get_module_class_from_name
  49. try:
  50. from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
  51. from torch_xla.distributed.fsdp import checkpoint_module
  52. from torch_xla.distributed.fsdp.wrap import (
  53. size_based_auto_wrap_policy,
  54. transformer_auto_wrap_policy,
  55. )
  56. if is_fsdp_xla_v2_enabled:
  57. from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
  58. SpmdFullyShardedDataParallel as FSDPv2,
  59. )
  60. except ImportError:
  61. raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
  62. auto_wrap_policy = None
  63. auto_wrapper_callable = None
  64. default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
  65. fsdp_transformer_layer_cls_to_wrap = args.fsdp_config.get(
  66. "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
  67. )
  68. if args.fsdp_config["min_num_params"] > 0:
  69. auto_wrap_policy = functools.partial(
  70. size_based_auto_wrap_policy, min_num_params=args.fsdp_config["min_num_params"]
  71. )
  72. elif fsdp_transformer_layer_cls_to_wrap is not None:
  73. transformer_cls_to_wrap = set()
  74. for layer_class in fsdp_transformer_layer_cls_to_wrap:
  75. transformer_cls = get_module_class_from_name(model, layer_class)
  76. if transformer_cls is None:
  77. raise Exception("Could not find the transformer layer class to wrap in the model.")
  78. else:
  79. transformer_cls_to_wrap.add(transformer_cls)
  80. auto_wrap_policy = functools.partial(
  81. transformer_auto_wrap_policy,
  82. # Transformer layer class to wrap
  83. transformer_layer_cls=transformer_cls_to_wrap,
  84. )
  85. fsdp_kwargs = args.xla_fsdp_config
  86. if args.fsdp_config["xla_fsdp_grad_ckpt"]:
  87. if model.config.use_cache:
  88. logger.warning_once(
  89. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  90. )
  91. model.config.use_cache = False
  92. # Apply gradient checkpointing to auto-wrapped sub-modules if specified
  93. def auto_wrapper_callable(m, *args, **kwargs):
  94. target_cls = FSDP if not is_fsdp_xla_v2_enabled else FSDPv2
  95. return target_cls(checkpoint_module(m), *args, **kwargs)
  96. # Wrap the base model with an outer FSDP wrapper
  97. if is_fsdp_xla_v2_enabled:
  98. def shard_output(output, mesh):
  99. from ..modeling_outputs import CausalLMOutputWithPast
  100. real_output = None
  101. if isinstance(output, torch.Tensor):
  102. real_output = output
  103. elif isinstance(output, tuple):
  104. real_output = output[0]
  105. elif isinstance(output, CausalLMOutputWithPast):
  106. real_output = output.logits
  107. if real_output is None:
  108. raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
  109. xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
  110. model = FSDPv2(
  111. model,
  112. shard_output=shard_output,
  113. auto_wrap_policy=auto_wrap_policy,
  114. auto_wrapper_callable=auto_wrapper_callable,
  115. )
  116. else:
  117. model = FSDP(
  118. model,
  119. auto_wrap_policy=auto_wrap_policy,
  120. auto_wrapper_callable=auto_wrapper_callable,
  121. **fsdp_kwargs,
  122. )
  123. # Patch `xm.optimizer_step` should not reduce gradients in this case,
  124. # as FSDP does not need gradient reduction over sharded parameters.
  125. def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
  126. loss = optimizer.step(**optimizer_args)
  127. if barrier:
  128. xm.mark_step()
  129. return loss
  130. xm.optimizer_step = patched_optimizer_step
  131. return model
  132. def save_tpu_checkpoint(model, args, accelerator, processing_class, is_fsdp_xla_v1_enabled, output_dir=None):
  133. """
  134. Saves a model checkpoint on TPU/XLA devices.
  135. Handles FSDP v1 sharded checkpoints (with consolidation on master), as well as
  136. standard XLA model saving via `save_pretrained` or `xm.save`.
  137. Args:
  138. model (`torch.nn.Module`): The model to save.
  139. args (`TrainingArguments`): The training arguments.
  140. accelerator (`Accelerator`): The accelerator instance.
  141. processing_class: The processing class (tokenizer/processor) to save alongside the model.
  142. is_fsdp_xla_v1_enabled (`bool`): Whether FSDP XLA v1 is enabled.
  143. output_dir (`str`, *optional*): The directory to save to. Defaults to `args.output_dir`.
  144. """
  145. import torch_xla.core.xla_model as xm
  146. output_dir = output_dir if output_dir is not None else args.output_dir
  147. logger.info(f"Saving model checkpoint to {output_dir}")
  148. xm.mark_step()
  149. if xm.is_master_ordinal(local=False):
  150. os.makedirs(output_dir, exist_ok=True)
  151. torch.save(args, os.path.join(output_dir, "training_args.bin"))
  152. # Save a trained model and configuration using `save_pretrained()`.
  153. # They can then be reloaded using `from_pretrained()`
  154. supported_classes = (PushToHubMixin,)
  155. xm.rendezvous("saving_checkpoint")
  156. if is_fsdp_xla_v1_enabled:
  157. ckpt = {
  158. "model": model.state_dict(),
  159. "shard_metadata": model.get_shard_metadata(),
  160. }
  161. ckpt_path = os.path.join(output_dir, f"rank{args.process_index}-of-{args.world_size}-{WEIGHTS_NAME}")
  162. # All ranks save sharded checkpoint
  163. xm.save(ckpt, ckpt_path, master_only=False)
  164. # Make sure all ranks have saved checkpoints
  165. xm.rendezvous("save_full_checkpoints")
  166. # Master save full checkpoint
  167. if args.should_save:
  168. from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
  169. full_state_dict, _ = consolidate_sharded_model_checkpoints(
  170. ckpt_prefix=os.path.join(output_dir, ""),
  171. ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
  172. save_model=False,
  173. )
  174. model = model.module.module
  175. unwrapped_model = accelerator.unwrap_model(model)
  176. if isinstance(unwrapped_model, supported_classes):
  177. unwrapped_model.save_pretrained(output_dir, state_dict=full_state_dict)
  178. else:
  179. logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
  180. xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
  181. elif not isinstance(model, supported_classes):
  182. if isinstance(accelerator.unwrap_model(model), supported_classes):
  183. accelerator.unwrap_model(model).save_pretrained(
  184. output_dir,
  185. is_main_process=args.should_save,
  186. state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
  187. )
  188. else:
  189. logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
  190. state_dict = xm._maybe_convert_to_cpu(model.state_dict())
  191. xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
  192. else:
  193. model.save_pretrained(
  194. output_dir,
  195. is_main_process=args.should_save,
  196. state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
  197. )
  198. if processing_class is not None and args.should_save:
  199. processing_class.save_pretrained(output_dir)