| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254 |
- # Copyright 2020-present the HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- PyTorch-independent utilities for the Trainer class.
- """
- import contextlib
- import copy
- import functools
- import gc
- import inspect
- import json
- import os
- import random
- import re
- import shutil
- import threading
- import time
- from collections.abc import Callable, Sized
- from functools import partial
- from pathlib import Path
- from typing import Any, NamedTuple, TypeGuard
- import numpy as np
- from .utils import (
- SAFE_WEIGHTS_INDEX_NAME,
- WEIGHTS_INDEX_NAME,
- ExplicitEnum,
- check_torch_load_is_safe,
- is_peft_available,
- is_psutil_available,
- is_torch_available,
- is_torch_cuda_available,
- is_torch_hpu_available,
- is_torch_mlu_available,
- is_torch_mps_available,
- is_torch_musa_available,
- is_torch_npu_available,
- is_torch_xla_available,
- is_torch_xpu_available,
- logging,
- requires_backends,
- )
- logger = logging.get_logger(__name__)
- if is_torch_available():
- import torch
- from safetensors.torch import load_file as safe_load_file
- if is_peft_available() and is_torch_available():
- from peft import PeftMixedModel, PeftModel
- def _is_peft_model(model):
- if is_peft_available():
- return isinstance(model, (PeftModel, PeftMixedModel))
- return False
- def unwrap_peft_model(model):
- """
- Extract the base model from a PEFT-wrapped model.
- If the model is not a PEFT model, returns it unchanged. Otherwise, attempts to
- unwrap the base model using ``get_base_model()`` or the ``base_model.model`` attribute.
- Args:
- model: The model to unwrap.
- Returns:
- The unwrapped base model.
- Raises:
- AttributeError: If the model is a PEFT model but cannot be unwrapped safely.
- """
- if not _is_peft_model(model):
- return model
- if hasattr(model, "get_base_model"):
- return model.get_base_model()
- elif hasattr(model, "base_model") and hasattr(model.base_model, "model"):
- # PeftMixedModel do not provide a `get_base_model` method
- return model.base_model.model
- else:
- raise AttributeError("Cannot extract base model safely from this PEFT wrapper.")
- def validate_quantization_for_training(model):
- """
- Validate that a quantized model is set up correctly for training.
- Raises `ValueError` when:
- - A quantized + compiled model is used (torch.compile is not supported with PEFT fine-tuning).
- - A purely quantized model has no trainable adapters attached (unless it supports QAT).
- - The quantization method does not support training.
- Args:
- model: The model to validate.
- """
- _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
- model, "_hf_peft_config_loaded", False
- )
- _quantization_method_supports_training = (
- getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
- )
- _is_model_quantized_and_qat_trainable = getattr(model, "hf_quantizer", None) is not None and getattr(
- model.hf_quantizer, "is_qat_trainable", False
- )
- # Filter out quantized + compiled models
- if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
- raise ValueError(
- "You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT"
- )
- # At this stage the model is already loaded
- if _is_quantized_and_base_model and not _is_peft_model(model) and not _is_model_quantized_and_qat_trainable:
- raise ValueError(
- "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
- " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
- " for more details"
- )
- elif _is_quantized_and_base_model and not _quantization_method_supports_training:
- raise ValueError(
- f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}"
- " but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers"
- f" to request the support for training support for {model.hf_quantizer.quantization_config.quant_method}"
- )
- def seed_worker(worker_id: int, num_workers: int, rank: int):
- """
- Helper function to set worker seed during Dataloader initialization.
- """
- init_seed = torch.initial_seed() % 2**32
- worker_seed = num_workers * rank + init_seed
- set_seed(worker_seed)
- def enable_full_determinism(seed: int, warn_only: bool = False):
- """
- Helper function for reproducible behavior during distributed training. See
- https://pytorch.org/docs/stable/notes/randomness.html for pytorch
- """
- # set seed first
- set_seed(seed)
- if is_torch_available():
- # Enable PyTorch deterministic mode. This potentially requires either the environment
- # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
- # depending on the CUDA version, so we set them both here
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
- os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
- # The environment variable required to enable deterministic mode on Ascend NPUs.
- os.environ["ASCEND_LAUNCH_BLOCKING"] = "1"
- os.environ["HCCL_DETERMINISTIC"] = "1"
- os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1"
- torch.use_deterministic_algorithms(True, warn_only=warn_only)
- # Enable CUDNN deterministic mode
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
- def set_seed(seed: int, deterministic: bool = False):
- """
- Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` (if installed).
- Args:
- seed (`int`):
- The seed to set.
- deterministic (`bool`, *optional*, defaults to `False`):
- Whether to use deterministic algorithms where available. Can slow down training.
- """
- random.seed(seed)
- np.random.seed(seed)
- if is_torch_available():
- torch.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
- # ^^ safe to call this function even if cuda is not available
- if deterministic:
- torch.use_deterministic_algorithms(True)
- if is_torch_mlu_available():
- torch.mlu.manual_seed_all(seed)
- if is_torch_musa_available():
- torch.musa.manual_seed_all(seed)
- if is_torch_npu_available():
- torch.npu.manual_seed_all(seed)
- if is_torch_hpu_available():
- torch.hpu.manual_seed_all(seed)
- if is_torch_xpu_available():
- torch.xpu.manual_seed_all(seed)
- class EvalPrediction:
- """
- Evaluation output (always contains labels), to be used to compute metrics.
- Parameters:
- predictions (`np.ndarray`): Predictions of the model.
- label_ids (`np.ndarray`): Targets to be matched.
- inputs (`np.ndarray`, *optional*): Input data passed to the model.
- losses (`np.ndarray`, *optional*): Loss values computed during evaluation.
- """
- def __init__(
- self,
- predictions: np.ndarray | tuple[np.ndarray],
- label_ids: np.ndarray | tuple[np.ndarray],
- inputs: np.ndarray | tuple[np.ndarray] | None = None,
- losses: np.ndarray | tuple[np.ndarray] | None = None,
- ):
- self.predictions = predictions
- self.label_ids = label_ids
- self.inputs = inputs
- self.losses = losses
- self.elements = (self.predictions, self.label_ids)
- if self.inputs is not None:
- self.elements += (self.inputs,)
- if self.losses is not None:
- self.elements += (self.losses,)
- def __iter__(self):
- return iter(self.elements)
- def __getitem__(self, idx):
- if idx < 0 or idx >= len(self.elements):
- raise IndexError("tuple index out of range")
- return self.elements[idx]
- class EvalLoopOutput(NamedTuple):
- predictions: np.ndarray | tuple[np.ndarray]
- label_ids: np.ndarray | tuple[np.ndarray] | None
- metrics: dict[str, float] | None
- num_samples: int | None
- class PredictionOutput(NamedTuple):
- predictions: np.ndarray | tuple[np.ndarray]
- label_ids: np.ndarray | tuple[np.ndarray] | None
- metrics: dict[str, float] | None
- class TrainOutput(NamedTuple):
- global_step: int
- training_loss: float
- metrics: dict[str, float]
- PREFIX_CHECKPOINT_DIR = "checkpoint"
- _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
- def get_last_checkpoint(folder):
- content = os.listdir(folder)
- checkpoints = [
- path
- for path in content
- if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
- ]
- if len(checkpoints) == 0:
- return
- return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
- def sort_checkpoints(
- output_dir: str,
- checkpoint_prefix: str = PREFIX_CHECKPOINT_DIR,
- use_mtime: bool = False,
- best_model_checkpoint: str | None = None,
- ) -> list[str]:
- """
- Return checkpoint directories sorted by step number (oldest first).
- Args:
- output_dir (`str`):
- The directory containing the checkpoints.
- checkpoint_prefix (`str`, *optional*, defaults to `"checkpoint"`):
- The prefix used for checkpoint directory names.
- use_mtime (`bool`, *optional*, defaults to `False`):
- Whether to sort by modification time instead of step number.
- best_model_checkpoint (`str`, *optional*):
- If provided, this checkpoint is moved to second-to-last position to protect
- it from deletion while keeping the most recent checkpoint last for resuming.
- Returns:
- `list[str]`: Sorted list of checkpoint directory paths (oldest first).
- """
- glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
- ordering_and_checkpoint_path = []
- for path in glob_checkpoints:
- if use_mtime:
- ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
- else:
- regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
- if regex_match is not None and regex_match.groups() is not None:
- ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
- checkpoints_sorted = sorted(ordering_and_checkpoint_path)
- # mtime is not reliable on some filesystems (e.g., cloud fuse filesystems)
- # so we check if the mtime is fake and fall back to numerical ordering
- if use_mtime and len(checkpoints_sorted) > 1:
- mtime_diff = checkpoints_sorted[-1][0] - checkpoints_sorted[0][0]
- if mtime_diff < 1.0:
- logger.warning_once("mtime may not be reliable on this filesystem, falling back to numerical ordering")
- return sort_checkpoints(
- output_dir, checkpoint_prefix, use_mtime=False, best_model_checkpoint=best_model_checkpoint
- )
- checkpoints_sorted = [path for _, path in checkpoints_sorted]
- # Move best_model_checkpoint to second-to-last position to protect it from deletion
- # while keeping the most recent checkpoint at the end for resuming training.
- if best_model_checkpoint is not None:
- best_model_checkpoint = str(Path(best_model_checkpoint))
- if best_model_checkpoint in checkpoints_sorted and checkpoints_sorted[-1] != best_model_checkpoint:
- most_recent = checkpoints_sorted[-1]
- checkpoints_sorted = [c for c in checkpoints_sorted if c not in {best_model_checkpoint, most_recent}]
- checkpoints_sorted += [best_model_checkpoint, most_recent]
- return checkpoints_sorted
- def rotate_checkpoints(
- output_dir: str,
- save_total_limit: int | None = None,
- best_model_checkpoint: str | None = None,
- use_mtime: bool = False,
- checkpoint_prefix: str = PREFIX_CHECKPOINT_DIR,
- ) -> None:
- """
- Delete older checkpoints, keeping at most `save_total_limit`.
- Always preserves the most recent checkpoint and the best model checkpoint (if provided).
- Args:
- output_dir (`str`):
- The directory containing the checkpoints.
- save_total_limit (`int`, *optional*):
- Maximum number of checkpoints to keep. No deletion if `None` or <= 0.
- best_model_checkpoint (`str`, *optional*):
- Path to best checkpoint (will always be preserved).
- use_mtime (`bool`, *optional*, defaults to `False`):
- Whether to sort by modification time instead of step number.
- checkpoint_prefix (`str`, *optional*, defaults to `"checkpoint"`):
- The prefix used for checkpoint directory names.
- """
- if save_total_limit is None or save_total_limit <= 0:
- return
- checkpoints = sort_checkpoints(output_dir, checkpoint_prefix, use_mtime)
- if len(checkpoints) <= save_total_limit:
- return
- # Checkpoints that must not be deleted
- protected = {checkpoints[-1]} # most recent, for resuming
- if best_model_checkpoint is not None:
- protected.add(str(Path(best_model_checkpoint)))
- # Delete oldest non-protected checkpoints until we have save_total_limit left
- num_to_keep = max(save_total_limit, len(protected))
- remaining = len(checkpoints)
- for checkpoint in checkpoints:
- if remaining <= num_to_keep:
- break
- if checkpoint not in protected:
- shutil.rmtree(checkpoint, ignore_errors=True)
- remaining -= 1
- class IntervalStrategy(ExplicitEnum):
- NO = "no"
- STEPS = "steps"
- EPOCH = "epoch"
- class SaveStrategy(ExplicitEnum):
- NO = "no"
- STEPS = "steps"
- EPOCH = "epoch"
- BEST = "best"
- class HubStrategy(ExplicitEnum):
- END = "end"
- EVERY_SAVE = "every_save"
- CHECKPOINT = "checkpoint"
- ALL_CHECKPOINTS = "all_checkpoints"
- class BestRun(NamedTuple):
- """
- The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]).
- Parameters:
- run_id (`str`):
- The id of the best run (if models were saved, the corresponding checkpoint will be in the folder ending
- with run-{run_id}).
- objective (`float`):
- The objective that was obtained for this run.
- hyperparameters (`dict[str, Any]`):
- The hyperparameters picked to get this run.
- run_summary (`Optional[Any]`):
- A summary of tuning experiments. `ray.tune.ExperimentAnalysis` object for Ray backend.
- """
- run_id: str
- objective: float | list[float]
- hyperparameters: dict[str, Any]
- run_summary: Any | None = None
- def default_compute_objective(metrics: dict[str, float]) -> float:
- """
- The default objective to maximize/minimize when doing an hyperparameter search. It is the evaluation loss if no
- metrics are provided to the [`Trainer`], the sum of all metrics otherwise.
- Args:
- metrics (`dict[str, float]`): The metrics returned by the evaluate method.
- Return:
- `float`: The objective to minimize or maximize
- """
- metrics = copy.deepcopy(metrics)
- loss = metrics.pop("eval_loss", None)
- _ = metrics.pop("epoch", None)
- # Remove speed metrics
- speed_metrics = [m for m in metrics if m.endswith("_runtime") or m.endswith("_per_second")]
- for sm in speed_metrics:
- _ = metrics.pop(sm, None)
- return loss if len(metrics) == 0 else sum(metrics.values())
- def default_hp_space_optuna(trial) -> dict[str, float]:
- from .integrations import is_optuna_available
- assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`"
- return {
- "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
- "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
- "seed": trial.suggest_int("seed", 1, 40),
- "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
- }
- def default_hp_space_ray(trial) -> dict[str, Any]:
- from .integrations import is_ray_tune_available
- assert is_ray_tune_available(), "This function needs ray installed: `pip install ray[tune]`"
- from ray import tune
- return {
- "learning_rate": tune.loguniform(1e-6, 1e-4),
- "num_train_epochs": tune.choice(list(range(1, 6))),
- "seed": tune.uniform(1, 40),
- "per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]),
- }
- def default_hp_space_wandb(trial) -> dict[str, Any]:
- from .integrations import is_wandb_available
- if not is_wandb_available():
- raise ImportError("This function needs wandb installed: `pip install wandb`")
- return {
- "method": "random",
- "metric": {"name": "objective", "goal": "minimize"},
- "parameters": {
- "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
- "num_train_epochs": {"distribution": "int_uniform", "min": 1, "max": 6},
- "seed": {"distribution": "int_uniform", "min": 1, "max": 40},
- "per_device_train_batch_size": {"values": [4, 8, 16, 32, 64]},
- },
- }
- class HPSearchBackend(ExplicitEnum):
- OPTUNA = "optuna"
- RAY = "ray"
- WANDB = "wandb"
- def is_main_process(local_rank):
- """
- Whether or not the current process is the local process, based on `xr.global_ordinal()` (for TPUs) first, then on
- `local_rank`.
- """
- if is_torch_xla_available():
- import torch_xla.runtime as xr
- return xr.global_ordinal() == 0
- return local_rank in [-1, 0]
- def total_processes_number(local_rank):
- """
- Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
- """
- if is_torch_xla_available():
- import torch_xla.runtime as xr
- return xr.world_size()
- elif local_rank != -1 and is_torch_available():
- import torch
- return torch.distributed.get_world_size()
- return 1
- def speed_metrics(split, start_time, num_samples=None, num_steps=None, num_tokens=None):
- """
- Measure and return speed performance metrics.
- This function requires a time snapshot `start_time` before the operation to be measured starts and this function
- should be run immediately after the operation to be measured has completed.
- Args:
- - split: name to prefix metric (like train, eval, test...)
- - start_time: operation start time
- - num_samples: number of samples processed
- - num_steps: number of steps processed
- - num_tokens: number of tokens processed
- """
- runtime = time.time() - start_time
- result = {f"{split}_runtime": round(runtime, 4)}
- if runtime == 0:
- return result
- if num_samples is not None:
- samples_per_second = num_samples / runtime
- result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
- if num_steps is not None:
- steps_per_second = num_steps / runtime
- result[f"{split}_steps_per_second"] = round(steps_per_second, 3)
- if num_tokens is not None:
- tokens_per_second = num_tokens / runtime
- result[f"{split}_tokens_per_second"] = round(tokens_per_second, 3)
- return result
- class SchedulerType(ExplicitEnum):
- """
- Scheduler names for the parameter `lr_scheduler_type` in [`TrainingArguments`].
- By default, it uses "linear". Internally, this retrieves `get_linear_schedule_with_warmup` scheduler from [`Trainer`].
- Scheduler types:
- - "linear" = [`get_linear_schedule_with_warmup`]
- - "cosine" = [`get_cosine_schedule_with_warmup`]
- - "cosine_with_restarts" = [`get_cosine_with_hard_restarts_schedule_with_warmup`]
- - "polynomial" = [`get_polynomial_decay_schedule_with_warmup`]
- - "constant" = [`get_constant_schedule`]
- - "constant_with_warmup" = [`get_constant_schedule_with_warmup`]
- - "inverse_sqrt" = [`get_inverse_sqrt_schedule`]
- - "reduce_lr_on_plateau" = [`get_reduce_on_plateau_schedule`]
- - "cosine_with_min_lr" = [`get_cosine_with_min_lr_schedule_with_warmup`]
- - "cosine_warmup_with_min_lr" = [`get_cosine_with_min_lr_schedule_with_warmup_lr_rate`]
- - "warmup_stable_decay" = [`get_wsd_schedule`]
- - "greedy" = [`get_greedy_schedule`]
- """
- LINEAR = "linear"
- COSINE = "cosine"
- COSINE_WITH_RESTARTS = "cosine_with_restarts"
- POLYNOMIAL = "polynomial"
- CONSTANT = "constant"
- CONSTANT_WITH_WARMUP = "constant_with_warmup"
- INVERSE_SQRT = "inverse_sqrt"
- REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
- COSINE_WITH_MIN_LR = "cosine_with_min_lr"
- COSINE_WARMUP_WITH_MIN_LR = "cosine_warmup_with_min_lr"
- WARMUP_STABLE_DECAY = "warmup_stable_decay"
- GREEDY = "greedy"
- class TrainerMemoryTracker:
- """
- A helper class that tracks cpu and gpu memory.
- This class will silently skip unless `psutil` is available. Install with `pip install psutil`.
- When a stage completes, it can pass metrics dict to update with the memory metrics gathered during this stage.
- Example :
- ```python
- self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
- self._memory_tracker.start()
- # code ...
- metrics = {"train_runtime": 10.5}
- self._memory_tracker.stop_and_update_metrics(metrics)
- ```
- To understand this class' intricacies please read the documentation of [`~Trainer.log_metrics`].
- """
- # map trainer methods to metrics prefix
- stages = {
- "__init__": "init",
- "train": "train",
- "_inner_training_loop": "train",
- "_finalize_training": "train",
- "evaluate": "eval",
- "predict": "test",
- }
- def __init__(self, skip_memory_metrics=False):
- self.skip_memory_metrics = skip_memory_metrics
- if not is_psutil_available():
- # soft dependency on psutil
- self.skip_memory_metrics = True
- if self.skip_memory_metrics:
- return
- import psutil
- if is_torch_cuda_available() or is_torch_mlu_available() or is_torch_musa_available():
- import torch
- self.torch = torch
- self.gpu = {}
- elif is_torch_mps_available():
- import torch
- self.torch = torch
- self.gpu = {}
- elif is_torch_xpu_available():
- import torch
- self.torch = torch
- self.gpu = {}
- elif is_torch_npu_available():
- import torch
- self.torch = torch
- self.gpu = {}
- elif is_torch_hpu_available():
- import torch
- self.torch = torch
- self.gpu = {}
- else:
- self.torch = None
- self.process = psutil.Process()
- self.cur_stage = None
- self.cpu = {}
- self.init_reported = False
- def derive_stage(self):
- """derives the stage/caller name automatically"""
- caller = inspect.currentframe().f_back.f_back.f_code.co_name
- if caller in self.stages:
- return self.stages[caller]
- else:
- raise ValueError(
- f"was called from {caller}, but only expect to be called from one of {self.stages.keys()}"
- )
- def cpu_mem_used(self):
- """get resident set size memory for the current process"""
- return self.process.memory_info().rss
- def peak_monitor_func(self):
- self.cpu_mem_used_peak = -1
- while True:
- self.cpu_mem_used_peak = max(self.cpu_mem_used(), self.cpu_mem_used_peak)
- # can't sleep or will not catch the peak right (this comment is here on purpose)
- # time.sleep(0.001) # 1msec
- if not self.peak_monitoring:
- break
- def start(self):
- """start tracking for the caller's stage"""
- if self.skip_memory_metrics:
- return
- stage = self.derive_stage()
- # deal with nested calls of eval during train - simply ignore those
- if self.cur_stage is not None and self.cur_stage != stage:
- return
- self.cur_stage = stage
- gc.collect()
- if self.torch is not None:
- if torch.cuda.is_available():
- self.torch.cuda.reset_peak_memory_stats()
- self.torch.cuda.empty_cache()
- elif is_torch_mlu_available():
- self.torch.mlu.reset_peak_memory_stats()
- self.torch.mlu.empty_cache()
- elif is_torch_musa_available():
- self.torch.musa.reset_peak_memory_stats()
- self.torch.musa.empty_cache()
- elif is_torch_xpu_available():
- self.torch.xpu.reset_peak_memory_stats()
- self.torch.xpu.empty_cache()
- elif is_torch_npu_available():
- self.torch.npu.reset_peak_memory_stats()
- self.torch.npu.empty_cache()
- elif is_torch_hpu_available():
- self.torch.hpu.reset_peak_memory_stats()
- # not available on hpu as it reserves all device memory for the current process
- # self.torch.hpu.empty_cache()
- elif is_torch_mps_available():
- self.torch.mps.empty_cache()
- # gpu
- if self.torch is not None:
- if torch.cuda.is_available():
- self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
- elif is_torch_mlu_available():
- self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated()
- elif is_torch_musa_available():
- self.gpu_mem_used_at_start = self.torch.musa.memory_allocated()
- elif is_torch_xpu_available():
- self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
- elif is_torch_npu_available():
- self.gpu_mem_used_at_start = self.torch.npu.memory_allocated()
- elif is_torch_hpu_available():
- self.gpu_mem_used_at_start = self.torch.hpu.memory_allocated()
- elif is_torch_mps_available():
- self.gpu_mem_used_at_start = self.torch.mps.current_allocated_memory()
- # cpu
- self.cpu_mem_used_at_start = self.cpu_mem_used()
- self.peak_monitoring = True
- peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
- peak_monitor_thread.daemon = True
- peak_monitor_thread.start()
- def stop(self, stage):
- """stop tracking for the passed stage"""
- # deal with nested calls of eval during train - simply ignore those
- if self.cur_stage is not None and self.cur_stage != stage:
- return
- # this sends a signal to peak_monitor_func to complete its loop
- self.peak_monitoring = False
- # first ensure all objects get collected and their memory is freed
- gc.collect()
- if self.torch is not None:
- if torch.cuda.is_available():
- self.torch.cuda.empty_cache()
- elif is_torch_mlu_available():
- self.torch.mlu.empty_cache()
- elif is_torch_musa_available():
- self.torch.musa.empty_cache()
- elif is_torch_xpu_available():
- self.torch.xpu.empty_cache()
- elif is_torch_npu_available():
- self.torch.npu.empty_cache()
- elif is_torch_hpu_available():
- # not available on hpu as it reserves all device memory for the current process
- # self.torch.npu.empty_cache()
- pass
- elif is_torch_mps_available():
- self.torch.mps.empty_cache()
- # concepts:
- # - alloc_delta: the difference of allocated memory between the end and the start
- # - peaked_delta: the difference between the peak memory and the current memory
- # in order to know how much memory the measured code consumed one needs to sum these two
- # gpu
- if self.torch is not None:
- if torch.cuda.is_available():
- self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
- self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
- elif is_torch_mlu_available():
- self.gpu_mem_used_now = self.torch.mlu.memory_allocated()
- self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated()
- elif is_torch_musa_available():
- self.gpu_mem_used_now = self.torch.musa.memory_allocated()
- self.gpu_mem_used_peak = self.torch.musa.max_memory_allocated()
- elif is_torch_xpu_available():
- self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
- self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
- elif is_torch_npu_available():
- self.gpu_mem_used_now = self.torch.npu.memory_allocated()
- self.gpu_mem_used_peak = self.torch.npu.max_memory_allocated()
- elif is_torch_hpu_available():
- self.gpu_mem_used_now = self.torch.hpu.memory_allocated()
- self.gpu_mem_used_peak = self.torch.hpu.max_memory_allocated()
- elif is_torch_mps_available():
- self.gpu_mem_used_now = self.torch.mps.current_allocated_memory()
- # self.torch.mps.max_memory_allocated() does not exist yet
- self.gpu_mem_used_peak = None
- else:
- raise ValueError("No available GPU device found!")
- self.gpu[self.cur_stage] = {
- "begin": self.gpu_mem_used_at_start,
- "end": self.gpu_mem_used_now,
- "alloc": (self.gpu_mem_used_now - self.gpu_mem_used_at_start),
- }
- if self.gpu_mem_used_peak is not None:
- self.gpu[self.cur_stage]["peaked"] = max(0, self.gpu_mem_used_peak - self.gpu_mem_used_now)
- else:
- self.gpu[self.cur_stage]["peaked"] = "Not available"
- # cpu
- self.cpu_mem_used_now = self.cpu_mem_used()
- self.cpu[self.cur_stage] = {
- "begin": self.cpu_mem_used_at_start,
- "end": self.cpu_mem_used_now,
- "alloc": (self.cpu_mem_used_now - self.cpu_mem_used_at_start),
- "peaked": max(0, self.cpu_mem_used_peak - self.cpu_mem_used_now),
- }
- # reset - cycle finished
- self.cur_stage = None
- def update_metrics(self, stage, metrics):
- """updates the metrics"""
- if self.skip_memory_metrics:
- return
- # deal with nested calls of eval during train - simply ignore those
- if self.cur_stage is not None and self.cur_stage != stage:
- return
- # since we don't have a way to return init metrics, we push them into the first of train/val/predict
- stages = [stage]
- if not self.init_reported:
- stages.insert(0, "init")
- self.init_reported = True
- for stage in stages:
- for t in ["alloc", "peaked"]:
- if stage in self.cpu and t in self.cpu[stage]:
- metrics[f"{stage}_mem_cpu_{t}_delta"] = self.cpu[stage][t]
- if self.torch is not None and stage in self.gpu and t in self.gpu[stage]:
- metrics[f"{stage}_mem_gpu_{t}_delta"] = self.gpu[stage][t]
- # if we need additional debug info, enable the following
- # for t in ["begin", "end"]:
- # if stage in self.cpu and t in self.cpu[stage]:
- # metrics[f"{stage}_mem_cpu_{t}"] = self.cpu[stage][t]
- # if self.torch is not None and stage in self.gpu and t in self.gpu[stage]:
- # metrics[f"{stage}_mem_gpu_{t}"] = self.gpu[stage][t]
- # since memory can be allocated before init, and it might be difficult to track overall
- # memory usage, in particular for GPU, let's report memory usage at the point init was called
- if stages[0] == "init":
- metrics["before_init_mem_cpu"] = self.cpu["init"]["begin"]
- if self.torch is not None:
- metrics["before_init_mem_gpu"] = self.gpu["init"]["begin"]
- # if we also wanted to report any additional memory allocations in between init and
- # whatever the next stage was we could also report this:
- # if self.cpu["init"]["end"] != self.cpu[stage]["begin"]:
- # metrics[f"after_init_mem_cpu_delta"] = self.cpu[stage]["begin"] - self.cpu["init"]["end"]
- # if self.torch is not None and self.gpu["init"]["end"] != self.gpu[stage]["begin"]:
- # metrics[f"after_init_mem_gpu_delta"] = self.gpu[stage]["begin"] - self.gpu["init"]["end"]
- def stop_and_update_metrics(self, metrics=None):
- """combine stop and metrics update in one call for simpler code"""
- if self.skip_memory_metrics:
- return
- stage = self.derive_stage()
- self.stop(stage)
- # init doesn't have metrics to update so we just save that data for later stages to retrieve
- if metrics is not None:
- self.update_metrics(stage, metrics)
- def has_length(dataset: Any) -> TypeGuard[Sized]:
- """
- Checks if the dataset implements __len__() and it doesn't raise an error
- """
- try:
- return len(dataset) is not None
- except TypeError:
- # TypeError: len() of unsized object
- return False
- except AttributeError:
- # Ray DataSets raises an AttributeError: https://github.com/ray-project/ray/blob/master/python/ray/data/dataset.py#L5616
- return False
- def denumpify_detensorize(metrics):
- """
- Recursively calls `.item()` on the element of the dictionary passed
- """
- if isinstance(metrics, (list, tuple)):
- return type(metrics)(denumpify_detensorize(m) for m in metrics)
- elif isinstance(metrics, dict):
- return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()})
- elif isinstance(metrics, np.generic):
- return metrics.item()
- elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1:
- return metrics.item()
- return metrics
- def number_of_arguments(func):
- """
- Return the number of arguments of the passed function, even if it's a partial function.
- """
- if isinstance(func, functools.partial):
- total_args = len(inspect.signature(func.func).parameters)
- return total_args - len(func.args) - len(func.keywords)
- return len(inspect.signature(func).parameters)
- def find_executable_batch_size(
- function: Callable | None = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
- ):
- """
- Args:
- A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
- CUDNN, the batch size is multiplied by 0.9 and passed to `function`. `function` must take in a `batch_size` parameter as
- its first argument.
- function (`Callable`, *optional*)
- A function to wrap
- starting_batch_size (`int`, *optional*)
- The batch size to try and fit into memory
- auto_find_batch_size (`bool`, *optional*)
- If False, will just execute `function`
- """
- if function is None:
- return functools.partial(
- find_executable_batch_size,
- starting_batch_size=starting_batch_size,
- auto_find_batch_size=auto_find_batch_size,
- )
- if auto_find_batch_size:
- requires_backends(find_executable_batch_size, "accelerate")
- from accelerate.utils import find_executable_batch_size as accelerate_find_executable_batch_size
- return accelerate_find_executable_batch_size(function=function, starting_batch_size=starting_batch_size)
- return functools.partial(function, batch_size=starting_batch_size)
- class FSDPOption(ExplicitEnum):
- FULL_SHARD = "full_shard"
- SHARD_GRAD_OP = "shard_grad_op"
- NO_SHARD = "no_shard"
- HYBRID_SHARD = "hybrid_shard"
- HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
- OFFLOAD = "offload"
- AUTO_WRAP = "auto_wrap"
- class RemoveColumnsCollator:
- """Wrap the data collator to remove unused columns before they are passed to the collator."""
- def __init__(
- self,
- data_collator,
- signature_columns,
- logger=None,
- model_name: str | None = None,
- description: str | None = None,
- ):
- self.data_collator = data_collator
- self.signature_columns = signature_columns
- self.logger = logger
- self.description = description
- self.model_name = model_name
- self.message_logged = False
- def _remove_columns(self, feature: dict) -> dict:
- if not isinstance(feature, dict):
- return feature
- if not self.message_logged and self.logger and self.model_name:
- ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
- if len(ignored_columns) > 0:
- dset_description = "" if self.description is None else f"in the {self.description} set"
- self.logger.info(
- f"The following columns {dset_description} don't have a corresponding argument in "
- f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
- f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
- " you can safely ignore this message."
- )
- self.message_logged = True
- return {k: v for k, v in feature.items() if k in self.signature_columns}
- def __call__(self, features: list[dict]):
- features = [self._remove_columns(feature) for feature in features]
- return self.data_collator(features)
- def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False):
- """A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules.
- Args:
- optim_target_modules (`Union[str, list[str]]`):
- A list of strings to try to match. Can be also a full string.
- key (`str`):
- A key to search any matches in optim_target_modules
- return_is_regex (`bool`):
- If set to `True`, the method will return whether the passed `optim_target_modules`
- is a regex or not.
- Returns:
- `bool` : True of match object if key matches any target modules from config, False or
- None if no match found
- `bool` : If the matched target module is a regex to silence out the warnings in Trainer
- for extra modules being found (only if `target_module_found=True` for an array of regex).
- """
- target_module_found = False
- is_regex = False
- if isinstance(optim_target_modules, str):
- target_module_found = bool(re.fullmatch(optim_target_modules, key))
- is_regex = optim_target_modules != key
- elif key in optim_target_modules: # from here, target_module_found must be a list of str
- # this module is specified directly in target_modules
- target_module_found = True
- elif any(target_key in key for target_key in optim_target_modules):
- target_module_found = True
- elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules):
- target_module_found = True
- is_regex = True
- if return_is_regex:
- return target_module_found, is_regex
- return target_module_found
- def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
- """
- This is the same as
- [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
- but for a sharded checkpoint.
- This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
- loaded in the model.
- Args:
- model (`torch.nn.Module`): The model in which to load the checkpoint.
- folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
- strict (`bool`, *optional*, defaults to `True`):
- Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
- prefer_safe (`bool`, *optional*, defaults to `True`):
- If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
- safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
- Returns:
- `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
- - `missing_keys` is a list of str containing the missing keys
- - `unexpected_keys` is a list of str containing the unexpected keys
- """
- # Load the index
- index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
- safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
- index_present = os.path.isfile(index_file)
- safe_index_present = os.path.isfile(safe_index_file)
- if not index_present and not safe_index_present:
- filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME)
- raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")
- load_safe = safe_index_present and (prefer_safe or not index_present)
- load_index = safe_index_file if load_safe else index_file
- with open(load_index, "r", encoding="utf-8") as f:
- index = json.load(f)
- shard_files = list(set(index["weight_map"].values()))
- # If strict=True, error before loading any of the state dicts.
- # TODO: Here, update the weight map with the config.dynamic_weight_conversion
- loaded_keys = index["weight_map"].keys()
- model_keys = model.state_dict().keys()
- missing_keys = [key for key in model_keys if key not in loaded_keys]
- unexpected_keys = [key for key in loaded_keys if key not in model_keys]
- if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
- error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
- if len(missing_keys) > 0:
- str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
- error_message += f"\nMissing key(s): {str_missing_keys}."
- if len(unexpected_keys) > 0:
- str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
- error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
- raise RuntimeError(error_message)
- if load_safe:
- loader = safe_load_file
- else:
- check_torch_load_is_safe()
- loader = partial(torch.load, map_location="cpu", weights_only=True)
- for shard_file in shard_files:
- state_dict = loader(os.path.join(folder, shard_file))
- model.load_state_dict(state_dict, strict=False)
- # Make sure memory is freed before we load the next state dict.
- del state_dict
- gc.collect()
- # Return the same thing as PyTorch load_state_dict function.
- return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
- def compare_trainer_and_checkpoint_args(training_args, trainer_state):
- """
- Compare training arguments with those stored in a checkpoint's trainer state.
- Logs a warning if there are mismatches between the current training arguments
- and the ones saved in the checkpoint.
- Args:
- training_args: The current training arguments.
- trainer_state: The trainer state loaded from a checkpoint.
- """
- attributes_map = {
- "logging_steps": "logging_steps",
- "eval_steps": "eval_steps",
- "save_steps": "save_steps",
- }
- has_warning = False
- warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: "
- for arg_attr, state_attr in attributes_map.items():
- arg_value = getattr(training_args, arg_attr, None)
- state_value = getattr(trainer_state, state_attr, None)
- if arg_value is not None and state_value is not None and arg_value != state_value:
- warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)"
- has_warning = True
- # train bs is special as we need to account for multi-GPU
- train_bs_args = training_args.per_device_train_batch_size
- train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu)
- if train_bs_args != train_bs_state:
- warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)"
- has_warning = True
- if has_warning:
- logger.warning_once(warning_str)
- def align_special_tokens(model, processing_class):
- """
- Aligns the special tokens of the tokenizer with the model configs.
- A new tokens may be defined in the tokenizer for fine-tuning purposes, e.g. an "end of turn" token may be
- added on chat models. In that case, we want the model configs to be aligned with the tokenizer, so that all
- downstream uses work as expected. This alignment should happen before training, to ensure the prediction step
- uses the new tokens as well.
- """
- from .processing_utils import ProcessorMixin
- from .tokenization_utils_base import PreTrainedTokenizerBase
- if isinstance(processing_class, ProcessorMixin):
- tokenizer: PreTrainedTokenizerBase = processing_class.tokenizer
- else:
- tokenizer = processing_class
- model_has_generation_config = hasattr(model, "generation_config") and model.generation_config is not None
- updated_tokens = {}
- # 1 - Align EOS token. EOS is more complex than the others, as `generation_config` may hold more than one EOS
- # token.
- tokenizer_has_new_eos = tokenizer.eos_token_id != getattr(model.config, "eos_token_id", None)
- if model_has_generation_config:
- # `generation_config.eos_token_id` is None: direct comparison
- if model.generation_config.eos_token_id is None:
- tokenizer_has_new_eos |= tokenizer.eos_token_id != model.generation_config.eos_token_id
- else:
- # `generation_config.eos_token_id` is an `int`: convert it to list (and continue below)
- if isinstance(model.generation_config.eos_token_id, int):
- model.generation_config.eos_token_id = [model.generation_config.eos_token_id]
- # `generation_config.eos_token_id` is a `list`: check if the tokenizer's EOS token is in the list
- tokenizer_has_new_eos |= tokenizer.eos_token_id not in model.generation_config.eos_token_id
- if tokenizer_has_new_eos:
- updated_tokens["eos_token_id"] = tokenizer.eos_token_id
- model.config.eos_token_id = tokenizer.eos_token_id
- # The generation config may hold more than one EOS token. We preserve the original EOS tokens: any of the
- # EOS tokens defined here will halt generation.
- if model_has_generation_config:
- all_eos_tokens = [tokenizer.eos_token_id]
- if model.generation_config.eos_token_id is not None:
- all_eos_tokens += list(model.generation_config.eos_token_id)
- model.generation_config.eos_token_id = [token for token in all_eos_tokens if token is not None]
- # 2 - Align BOS
- tokenizer_has_new_bos = tokenizer.bos_token_id != getattr(model.config, "bos_token_id", None)
- if model_has_generation_config:
- tokenizer_has_new_bos |= tokenizer.bos_token_id != model.generation_config.bos_token_id
- if tokenizer_has_new_bos:
- updated_tokens["bos_token_id"] = tokenizer.bos_token_id
- model.config.bos_token_id = tokenizer.bos_token_id
- if model_has_generation_config:
- model.generation_config.bos_token_id = tokenizer.bos_token_id
- # 3 - Align PAD
- tokenizer_has_new_pad = tokenizer.pad_token_id != getattr(model.config, "pad_token_id", None)
- if model_has_generation_config:
- tokenizer_has_new_pad |= tokenizer.pad_token_id != model.generation_config.pad_token_id
- if tokenizer_has_new_pad:
- updated_tokens["pad_token_id"] = tokenizer.pad_token_id
- model.config.pad_token_id = tokenizer.pad_token_id
- if model_has_generation_config:
- model.generation_config.pad_token_id = tokenizer.pad_token_id
- # 4 - Warn users about the changes
- if len(updated_tokens) > 0:
- logger.warning(
- "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. "
- "The model config and generation config were aligned accordingly, being updated with the tokenizer's "
- f"values. Updated tokens: {updated_tokens}."
- )
- @contextlib.contextmanager
- def suppress_progress_bars():
- """Context manager that suppresses huggingface_hub progress bars."""
- import huggingface_hub.utils as hf_hub_utils
- hf_hub_utils.disable_progress_bars()
- try:
- yield
- finally:
- hf_hub_utils.enable_progress_bars()
|