trainer_utils.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254
  1. # Copyright 2020-present the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. PyTorch-independent utilities for the Trainer class.
  16. """
  17. import contextlib
  18. import copy
  19. import functools
  20. import gc
  21. import inspect
  22. import json
  23. import os
  24. import random
  25. import re
  26. import shutil
  27. import threading
  28. import time
  29. from collections.abc import Callable, Sized
  30. from functools import partial
  31. from pathlib import Path
  32. from typing import Any, NamedTuple, TypeGuard
  33. import numpy as np
  34. from .utils import (
  35. SAFE_WEIGHTS_INDEX_NAME,
  36. WEIGHTS_INDEX_NAME,
  37. ExplicitEnum,
  38. check_torch_load_is_safe,
  39. is_peft_available,
  40. is_psutil_available,
  41. is_torch_available,
  42. is_torch_cuda_available,
  43. is_torch_hpu_available,
  44. is_torch_mlu_available,
  45. is_torch_mps_available,
  46. is_torch_musa_available,
  47. is_torch_npu_available,
  48. is_torch_xla_available,
  49. is_torch_xpu_available,
  50. logging,
  51. requires_backends,
  52. )
# Module-level logger used by the helpers below (e.g. `sort_checkpoints`).
logger = logging.get_logger(__name__)

# Optional dependencies: torch and the safetensors torch loader are only imported when torch is installed.
if is_torch_available():
    import torch
    from safetensors.torch import load_file as safe_load_file

# PEFT wrapper classes checked by `_is_peft_model`; imported only when both peft and torch are available.
if is_peft_available() and is_torch_available():
    from peft import PeftMixedModel, PeftModel
  59. def _is_peft_model(model):
  60. if is_peft_available():
  61. return isinstance(model, (PeftModel, PeftMixedModel))
  62. return False
  63. def unwrap_peft_model(model):
  64. """
  65. Extract the base model from a PEFT-wrapped model.
  66. If the model is not a PEFT model, returns it unchanged. Otherwise, attempts to
  67. unwrap the base model using ``get_base_model()`` or the ``base_model.model`` attribute.
  68. Args:
  69. model: The model to unwrap.
  70. Returns:
  71. The unwrapped base model.
  72. Raises:
  73. AttributeError: If the model is a PEFT model but cannot be unwrapped safely.
  74. """
  75. if not _is_peft_model(model):
  76. return model
  77. if hasattr(model, "get_base_model"):
  78. return model.get_base_model()
  79. elif hasattr(model, "base_model") and hasattr(model.base_model, "model"):
  80. # PeftMixedModel do not provide a `get_base_model` method
  81. return model.base_model.model
  82. else:
  83. raise AttributeError("Cannot extract base model safely from this PEFT wrapper.")
  84. def validate_quantization_for_training(model):
  85. """
  86. Validate that a quantized model is set up correctly for training.
  87. Raises `ValueError` when:
  88. - A quantized + compiled model is used (torch.compile is not supported with PEFT fine-tuning).
  89. - A purely quantized model has no trainable adapters attached (unless it supports QAT).
  90. - The quantization method does not support training.
  91. Args:
  92. model: The model to validate.
  93. """
  94. _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
  95. model, "_hf_peft_config_loaded", False
  96. )
  97. _quantization_method_supports_training = (
  98. getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
  99. )
  100. _is_model_quantized_and_qat_trainable = getattr(model, "hf_quantizer", None) is not None and getattr(
  101. model.hf_quantizer, "is_qat_trainable", False
  102. )
  103. # Filter out quantized + compiled models
  104. if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
  105. raise ValueError(
  106. "You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT"
  107. )
  108. # At this stage the model is already loaded
  109. if _is_quantized_and_base_model and not _is_peft_model(model) and not _is_model_quantized_and_qat_trainable:
  110. raise ValueError(
  111. "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
  112. " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
  113. " for more details"
  114. )
  115. elif _is_quantized_and_base_model and not _quantization_method_supports_training:
  116. raise ValueError(
  117. f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}"
  118. " but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers"
  119. f" to request the support for training support for {model.hf_quantizer.quantization_config.quant_method}"
  120. )
  121. def seed_worker(worker_id: int, num_workers: int, rank: int):
  122. """
  123. Helper function to set worker seed during Dataloader initialization.
  124. """
  125. init_seed = torch.initial_seed() % 2**32
  126. worker_seed = num_workers * rank + init_seed
  127. set_seed(worker_seed)
  128. def enable_full_determinism(seed: int, warn_only: bool = False):
  129. """
  130. Helper function for reproducible behavior during distributed training. See
  131. https://pytorch.org/docs/stable/notes/randomness.html for pytorch
  132. """
  133. # set seed first
  134. set_seed(seed)
  135. if is_torch_available():
  136. # Enable PyTorch deterministic mode. This potentially requires either the environment
  137. # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
  138. # depending on the CUDA version, so we set them both here
  139. os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
  140. os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
  141. # The environment variable required to enable deterministic mode on Ascend NPUs.
  142. os.environ["ASCEND_LAUNCH_BLOCKING"] = "1"
  143. os.environ["HCCL_DETERMINISTIC"] = "1"
  144. os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1"
  145. torch.use_deterministic_algorithms(True, warn_only=warn_only)
  146. # Enable CUDNN deterministic mode
  147. torch.backends.cudnn.deterministic = True
  148. torch.backends.cudnn.benchmark = False
  149. def set_seed(seed: int, deterministic: bool = False):
  150. """
  151. Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` (if installed).
  152. Args:
  153. seed (`int`):
  154. The seed to set.
  155. deterministic (`bool`, *optional*, defaults to `False`):
  156. Whether to use deterministic algorithms where available. Can slow down training.
  157. """
  158. random.seed(seed)
  159. np.random.seed(seed)
  160. if is_torch_available():
  161. torch.manual_seed(seed)
  162. torch.cuda.manual_seed_all(seed)
  163. # ^^ safe to call this function even if cuda is not available
  164. if deterministic:
  165. torch.use_deterministic_algorithms(True)
  166. if is_torch_mlu_available():
  167. torch.mlu.manual_seed_all(seed)
  168. if is_torch_musa_available():
  169. torch.musa.manual_seed_all(seed)
  170. if is_torch_npu_available():
  171. torch.npu.manual_seed_all(seed)
  172. if is_torch_hpu_available():
  173. torch.hpu.manual_seed_all(seed)
  174. if is_torch_xpu_available():
  175. torch.xpu.manual_seed_all(seed)
  176. class EvalPrediction:
  177. """
  178. Evaluation output (always contains labels), to be used to compute metrics.
  179. Parameters:
  180. predictions (`np.ndarray`): Predictions of the model.
  181. label_ids (`np.ndarray`): Targets to be matched.
  182. inputs (`np.ndarray`, *optional*): Input data passed to the model.
  183. losses (`np.ndarray`, *optional*): Loss values computed during evaluation.
  184. """
  185. def __init__(
  186. self,
  187. predictions: np.ndarray | tuple[np.ndarray],
  188. label_ids: np.ndarray | tuple[np.ndarray],
  189. inputs: np.ndarray | tuple[np.ndarray] | None = None,
  190. losses: np.ndarray | tuple[np.ndarray] | None = None,
  191. ):
  192. self.predictions = predictions
  193. self.label_ids = label_ids
  194. self.inputs = inputs
  195. self.losses = losses
  196. self.elements = (self.predictions, self.label_ids)
  197. if self.inputs is not None:
  198. self.elements += (self.inputs,)
  199. if self.losses is not None:
  200. self.elements += (self.losses,)
  201. def __iter__(self):
  202. return iter(self.elements)
  203. def __getitem__(self, idx):
  204. if idx < 0 or idx >= len(self.elements):
  205. raise IndexError("tuple index out of range")
  206. return self.elements[idx]
class EvalLoopOutput(NamedTuple):
    """
    Named result tuple of an evaluation loop.

    Fields:
        predictions: A single array or a tuple of arrays.
        label_ids: A single array, a tuple of arrays, or `None`.
        metrics: Mapping of metric name to float value, or `None`.
        num_samples: Presumably the number of evaluated samples (may be `None`) — confirm against callers.
    """

    predictions: np.ndarray | tuple[np.ndarray]
    label_ids: np.ndarray | tuple[np.ndarray] | None
    metrics: dict[str, float] | None
    num_samples: int | None
class PredictionOutput(NamedTuple):
    """
    Named result tuple for prediction: the same first three fields as `EvalLoopOutput`, without `num_samples`.

    Fields:
        predictions: A single array or a tuple of arrays.
        label_ids: A single array, a tuple of arrays, or `None`.
        metrics: Mapping of metric name to float value, or `None`.
    """

    predictions: np.ndarray | tuple[np.ndarray]
    label_ids: np.ndarray | tuple[np.ndarray] | None
    metrics: dict[str, float] | None
class TrainOutput(NamedTuple):
    """
    Named result tuple for training.

    Fields:
        global_step: Integer step counter.
        training_loss: Float loss value.
        metrics: Mapping of metric name to float value.
    """

    global_step: int
    training_loss: float
    metrics: dict[str, float]
  220. PREFIX_CHECKPOINT_DIR = "checkpoint"
  221. _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
  222. def get_last_checkpoint(folder):
  223. content = os.listdir(folder)
  224. checkpoints = [
  225. path
  226. for path in content
  227. if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
  228. ]
  229. if len(checkpoints) == 0:
  230. return
  231. return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
  232. def sort_checkpoints(
  233. output_dir: str,
  234. checkpoint_prefix: str = PREFIX_CHECKPOINT_DIR,
  235. use_mtime: bool = False,
  236. best_model_checkpoint: str | None = None,
  237. ) -> list[str]:
  238. """
  239. Return checkpoint directories sorted by step number (oldest first).
  240. Args:
  241. output_dir (`str`):
  242. The directory containing the checkpoints.
  243. checkpoint_prefix (`str`, *optional*, defaults to `"checkpoint"`):
  244. The prefix used for checkpoint directory names.
  245. use_mtime (`bool`, *optional*, defaults to `False`):
  246. Whether to sort by modification time instead of step number.
  247. best_model_checkpoint (`str`, *optional*):
  248. If provided, this checkpoint is moved to second-to-last position to protect
  249. it from deletion while keeping the most recent checkpoint last for resuming.
  250. Returns:
  251. `list[str]`: Sorted list of checkpoint directory paths (oldest first).
  252. """
  253. glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
  254. ordering_and_checkpoint_path = []
  255. for path in glob_checkpoints:
  256. if use_mtime:
  257. ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
  258. else:
  259. regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
  260. if regex_match is not None and regex_match.groups() is not None:
  261. ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
  262. checkpoints_sorted = sorted(ordering_and_checkpoint_path)
  263. # mtime is not reliable on some filesystems (e.g., cloud fuse filesystems)
  264. # so we check if the mtime is fake and fall back to numerical ordering
  265. if use_mtime and len(checkpoints_sorted) > 1:
  266. mtime_diff = checkpoints_sorted[-1][0] - checkpoints_sorted[0][0]
  267. if mtime_diff < 1.0:
  268. logger.warning_once("mtime may not be reliable on this filesystem, falling back to numerical ordering")
  269. return sort_checkpoints(
  270. output_dir, checkpoint_prefix, use_mtime=False, best_model_checkpoint=best_model_checkpoint
  271. )
  272. checkpoints_sorted = [path for _, path in checkpoints_sorted]
  273. # Move best_model_checkpoint to second-to-last position to protect it from deletion
  274. # while keeping the most recent checkpoint at the end for resuming training.
  275. if best_model_checkpoint is not None:
  276. best_model_checkpoint = str(Path(best_model_checkpoint))
  277. if best_model_checkpoint in checkpoints_sorted and checkpoints_sorted[-1] != best_model_checkpoint:
  278. most_recent = checkpoints_sorted[-1]
  279. checkpoints_sorted = [c for c in checkpoints_sorted if c not in {best_model_checkpoint, most_recent}]
  280. checkpoints_sorted += [best_model_checkpoint, most_recent]
  281. return checkpoints_sorted
  282. def rotate_checkpoints(
  283. output_dir: str,
  284. save_total_limit: int | None = None,
  285. best_model_checkpoint: str | None = None,
  286. use_mtime: bool = False,
  287. checkpoint_prefix: str = PREFIX_CHECKPOINT_DIR,
  288. ) -> None:
  289. """
  290. Delete older checkpoints, keeping at most `save_total_limit`.
  291. Always preserves the most recent checkpoint and the best model checkpoint (if provided).
  292. Args:
  293. output_dir (`str`):
  294. The directory containing the checkpoints.
  295. save_total_limit (`int`, *optional*):
  296. Maximum number of checkpoints to keep. No deletion if `None` or <= 0.
  297. best_model_checkpoint (`str`, *optional*):
  298. Path to best checkpoint (will always be preserved).
  299. use_mtime (`bool`, *optional*, defaults to `False`):
  300. Whether to sort by modification time instead of step number.
  301. checkpoint_prefix (`str`, *optional*, defaults to `"checkpoint"`):
  302. The prefix used for checkpoint directory names.
  303. """
  304. if save_total_limit is None or save_total_limit <= 0:
  305. return
  306. checkpoints = sort_checkpoints(output_dir, checkpoint_prefix, use_mtime)
  307. if len(checkpoints) <= save_total_limit:
  308. return
  309. # Checkpoints that must not be deleted
  310. protected = {checkpoints[-1]} # most recent, for resuming
  311. if best_model_checkpoint is not None:
  312. protected.add(str(Path(best_model_checkpoint)))
  313. # Delete oldest non-protected checkpoints until we have save_total_limit left
  314. num_to_keep = max(save_total_limit, len(protected))
  315. remaining = len(checkpoints)
  316. for checkpoint in checkpoints:
  317. if remaining <= num_to_keep:
  318. break
  319. if checkpoint not in protected:
  320. shutil.rmtree(checkpoint, ignore_errors=True)
  321. remaining -= 1
class IntervalStrategy(ExplicitEnum):
    # String-valued interval options: never, every N steps, or every epoch.
    # NOTE(review): presumably consumed by TrainingArguments' eval/logging strategy options — confirm.
    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"
class SaveStrategy(ExplicitEnum):
    # Same values as IntervalStrategy plus "best".
    # NOTE(review): "best" presumably means saving when the tracked metric improves — confirm in Trainer.
    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"
    BEST = "best"
class HubStrategy(ExplicitEnum):
    # String-valued options naming when/what to push to the Hub.
    # NOTE(review): exact push semantics of each value live in the Trainer, not in this file.
    END = "end"
    EVERY_SAVE = "every_save"
    CHECKPOINT = "checkpoint"
    ALL_CHECKPOINTS = "all_checkpoints"
class BestRun(NamedTuple):
    """
    The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]).

    Parameters:
        run_id (`str`):
            The id of the best run (if models were saved, the corresponding checkpoint will be in the folder ending
            with run-{run_id}).
        objective (`float` or `list[float]`):
            The objective that was obtained for this run (a list when several objectives were optimized).
        hyperparameters (`dict[str, Any]`):
            The hyperparameters picked to get this run.
        run_summary (`Any`, *optional*, defaults to `None`):
            A summary of tuning experiments. `ray.tune.ExperimentAnalysis` object for Ray backend.
    """

    run_id: str
    objective: float | list[float]
    hyperparameters: dict[str, Any]
    run_summary: Any | None = None
  354. def default_compute_objective(metrics: dict[str, float]) -> float:
  355. """
  356. The default objective to maximize/minimize when doing an hyperparameter search. It is the evaluation loss if no
  357. metrics are provided to the [`Trainer`], the sum of all metrics otherwise.
  358. Args:
  359. metrics (`dict[str, float]`): The metrics returned by the evaluate method.
  360. Return:
  361. `float`: The objective to minimize or maximize
  362. """
  363. metrics = copy.deepcopy(metrics)
  364. loss = metrics.pop("eval_loss", None)
  365. _ = metrics.pop("epoch", None)
  366. # Remove speed metrics
  367. speed_metrics = [m for m in metrics if m.endswith("_runtime") or m.endswith("_per_second")]
  368. for sm in speed_metrics:
  369. _ = metrics.pop(sm, None)
  370. return loss if len(metrics) == 0 else sum(metrics.values())
  371. def default_hp_space_optuna(trial) -> dict[str, float]:
  372. from .integrations import is_optuna_available
  373. assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`"
  374. return {
  375. "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
  376. "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
  377. "seed": trial.suggest_int("seed", 1, 40),
  378. "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
  379. }
  380. def default_hp_space_ray(trial) -> dict[str, Any]:
  381. from .integrations import is_ray_tune_available
  382. assert is_ray_tune_available(), "This function needs ray installed: `pip install ray[tune]`"
  383. from ray import tune
  384. return {
  385. "learning_rate": tune.loguniform(1e-6, 1e-4),
  386. "num_train_epochs": tune.choice(list(range(1, 6))),
  387. "seed": tune.uniform(1, 40),
  388. "per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]),
  389. }
  390. def default_hp_space_wandb(trial) -> dict[str, Any]:
  391. from .integrations import is_wandb_available
  392. if not is_wandb_available():
  393. raise ImportError("This function needs wandb installed: `pip install wandb`")
  394. return {
  395. "method": "random",
  396. "metric": {"name": "objective", "goal": "minimize"},
  397. "parameters": {
  398. "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
  399. "num_train_epochs": {"distribution": "int_uniform", "min": 1, "max": 6},
  400. "seed": {"distribution": "int_uniform", "min": 1, "max": 40},
  401. "per_device_train_batch_size": {"values": [4, 8, 16, 32, 64]},
  402. },
  403. }
class HPSearchBackend(ExplicitEnum):
    # Supported hyperparameter-search backends; each has a matching `default_hp_space_*` helper in this file.
    OPTUNA = "optuna"
    RAY = "ray"
    WANDB = "wandb"
  408. def is_main_process(local_rank):
  409. """
  410. Whether or not the current process is the local process, based on `xr.global_ordinal()` (for TPUs) first, then on
  411. `local_rank`.
  412. """
  413. if is_torch_xla_available():
  414. import torch_xla.runtime as xr
  415. return xr.global_ordinal() == 0
  416. return local_rank in [-1, 0]
  417. def total_processes_number(local_rank):
  418. """
  419. Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
  420. """
  421. if is_torch_xla_available():
  422. import torch_xla.runtime as xr
  423. return xr.world_size()
  424. elif local_rank != -1 and is_torch_available():
  425. import torch
  426. return torch.distributed.get_world_size()
  427. return 1
  428. def speed_metrics(split, start_time, num_samples=None, num_steps=None, num_tokens=None):
  429. """
  430. Measure and return speed performance metrics.
  431. This function requires a time snapshot `start_time` before the operation to be measured starts and this function
  432. should be run immediately after the operation to be measured has completed.
  433. Args:
  434. - split: name to prefix metric (like train, eval, test...)
  435. - start_time: operation start time
  436. - num_samples: number of samples processed
  437. - num_steps: number of steps processed
  438. - num_tokens: number of tokens processed
  439. """
  440. runtime = time.time() - start_time
  441. result = {f"{split}_runtime": round(runtime, 4)}
  442. if runtime == 0:
  443. return result
  444. if num_samples is not None:
  445. samples_per_second = num_samples / runtime
  446. result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
  447. if num_steps is not None:
  448. steps_per_second = num_steps / runtime
  449. result[f"{split}_steps_per_second"] = round(steps_per_second, 3)
  450. if num_tokens is not None:
  451. tokens_per_second = num_tokens / runtime
  452. result[f"{split}_tokens_per_second"] = round(tokens_per_second, 3)
  453. return result
class SchedulerType(ExplicitEnum):
    """
    Scheduler names for the parameter `lr_scheduler_type` in [`TrainingArguments`].
    By default, it uses "linear". Internally, this retrieves `get_linear_schedule_with_warmup` scheduler from [`Trainer`].

    Each member's value is the string accepted for `lr_scheduler_type`.

    Scheduler types:
        - "linear" = [`get_linear_schedule_with_warmup`]
        - "cosine" = [`get_cosine_schedule_with_warmup`]
        - "cosine_with_restarts" = [`get_cosine_with_hard_restarts_schedule_with_warmup`]
        - "polynomial" = [`get_polynomial_decay_schedule_with_warmup`]
        - "constant" = [`get_constant_schedule`]
        - "constant_with_warmup" = [`get_constant_schedule_with_warmup`]
        - "inverse_sqrt" = [`get_inverse_sqrt_schedule`]
        - "reduce_lr_on_plateau" = [`get_reduce_on_plateau_schedule`]
        - "cosine_with_min_lr" = [`get_cosine_with_min_lr_schedule_with_warmup`]
        - "cosine_warmup_with_min_lr" = [`get_cosine_with_min_lr_schedule_with_warmup_lr_rate`]
        - "warmup_stable_decay" = [`get_wsd_schedule`]
        - "greedy" = [`get_greedy_schedule`]
    """

    LINEAR = "linear"
    COSINE = "cosine"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    POLYNOMIAL = "polynomial"
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    INVERSE_SQRT = "inverse_sqrt"
    REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
    COSINE_WITH_MIN_LR = "cosine_with_min_lr"
    COSINE_WARMUP_WITH_MIN_LR = "cosine_warmup_with_min_lr"
    WARMUP_STABLE_DECAY = "warmup_stable_decay"
    GREEDY = "greedy"
  484. class TrainerMemoryTracker:
  485. """
  486. A helper class that tracks cpu and gpu memory.
  487. This class will silently skip unless `psutil` is available. Install with `pip install psutil`.
  488. When a stage completes, it can pass metrics dict to update with the memory metrics gathered during this stage.
  489. Example :
  490. ```python
  491. self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
  492. self._memory_tracker.start()
  493. # code ...
  494. metrics = {"train_runtime": 10.5}
  495. self._memory_tracker.stop_and_update_metrics(metrics)
  496. ```
  497. To understand this class' intricacies please read the documentation of [`~Trainer.log_metrics`].
  498. """
  499. # map trainer methods to metrics prefix
  500. stages = {
  501. "__init__": "init",
  502. "train": "train",
  503. "_inner_training_loop": "train",
  504. "_finalize_training": "train",
  505. "evaluate": "eval",
  506. "predict": "test",
  507. }
  508. def __init__(self, skip_memory_metrics=False):
  509. self.skip_memory_metrics = skip_memory_metrics
  510. if not is_psutil_available():
  511. # soft dependency on psutil
  512. self.skip_memory_metrics = True
  513. if self.skip_memory_metrics:
  514. return
  515. import psutil
  516. if is_torch_cuda_available() or is_torch_mlu_available() or is_torch_musa_available():
  517. import torch
  518. self.torch = torch
  519. self.gpu = {}
  520. elif is_torch_mps_available():
  521. import torch
  522. self.torch = torch
  523. self.gpu = {}
  524. elif is_torch_xpu_available():
  525. import torch
  526. self.torch = torch
  527. self.gpu = {}
  528. elif is_torch_npu_available():
  529. import torch
  530. self.torch = torch
  531. self.gpu = {}
  532. elif is_torch_hpu_available():
  533. import torch
  534. self.torch = torch
  535. self.gpu = {}
  536. else:
  537. self.torch = None
  538. self.process = psutil.Process()
  539. self.cur_stage = None
  540. self.cpu = {}
  541. self.init_reported = False
  542. def derive_stage(self):
  543. """derives the stage/caller name automatically"""
  544. caller = inspect.currentframe().f_back.f_back.f_code.co_name
  545. if caller in self.stages:
  546. return self.stages[caller]
  547. else:
  548. raise ValueError(
  549. f"was called from {caller}, but only expect to be called from one of {self.stages.keys()}"
  550. )
  551. def cpu_mem_used(self):
  552. """get resident set size memory for the current process"""
  553. return self.process.memory_info().rss
  554. def peak_monitor_func(self):
  555. self.cpu_mem_used_peak = -1
  556. while True:
  557. self.cpu_mem_used_peak = max(self.cpu_mem_used(), self.cpu_mem_used_peak)
  558. # can't sleep or will not catch the peak right (this comment is here on purpose)
  559. # time.sleep(0.001) # 1msec
  560. if not self.peak_monitoring:
  561. break
  562. def start(self):
  563. """start tracking for the caller's stage"""
  564. if self.skip_memory_metrics:
  565. return
  566. stage = self.derive_stage()
  567. # deal with nested calls of eval during train - simply ignore those
  568. if self.cur_stage is not None and self.cur_stage != stage:
  569. return
  570. self.cur_stage = stage
  571. gc.collect()
  572. if self.torch is not None:
  573. if torch.cuda.is_available():
  574. self.torch.cuda.reset_peak_memory_stats()
  575. self.torch.cuda.empty_cache()
  576. elif is_torch_mlu_available():
  577. self.torch.mlu.reset_peak_memory_stats()
  578. self.torch.mlu.empty_cache()
  579. elif is_torch_musa_available():
  580. self.torch.musa.reset_peak_memory_stats()
  581. self.torch.musa.empty_cache()
  582. elif is_torch_xpu_available():
  583. self.torch.xpu.reset_peak_memory_stats()
  584. self.torch.xpu.empty_cache()
  585. elif is_torch_npu_available():
  586. self.torch.npu.reset_peak_memory_stats()
  587. self.torch.npu.empty_cache()
  588. elif is_torch_hpu_available():
  589. self.torch.hpu.reset_peak_memory_stats()
  590. # not available on hpu as it reserves all device memory for the current process
  591. # self.torch.hpu.empty_cache()
  592. elif is_torch_mps_available():
  593. self.torch.mps.empty_cache()
  594. # gpu
  595. if self.torch is not None:
  596. if torch.cuda.is_available():
  597. self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
  598. elif is_torch_mlu_available():
  599. self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated()
  600. elif is_torch_musa_available():
  601. self.gpu_mem_used_at_start = self.torch.musa.memory_allocated()
  602. elif is_torch_xpu_available():
  603. self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
  604. elif is_torch_npu_available():
  605. self.gpu_mem_used_at_start = self.torch.npu.memory_allocated()
  606. elif is_torch_hpu_available():
  607. self.gpu_mem_used_at_start = self.torch.hpu.memory_allocated()
  608. elif is_torch_mps_available():
  609. self.gpu_mem_used_at_start = self.torch.mps.current_allocated_memory()
  610. # cpu
  611. self.cpu_mem_used_at_start = self.cpu_mem_used()
  612. self.peak_monitoring = True
  613. peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
  614. peak_monitor_thread.daemon = True
  615. peak_monitor_thread.start()
  616. def stop(self, stage):
  617. """stop tracking for the passed stage"""
  618. # deal with nested calls of eval during train - simply ignore those
  619. if self.cur_stage is not None and self.cur_stage != stage:
  620. return
  621. # this sends a signal to peak_monitor_func to complete its loop
  622. self.peak_monitoring = False
  623. # first ensure all objects get collected and their memory is freed
  624. gc.collect()
  625. if self.torch is not None:
  626. if torch.cuda.is_available():
  627. self.torch.cuda.empty_cache()
  628. elif is_torch_mlu_available():
  629. self.torch.mlu.empty_cache()
  630. elif is_torch_musa_available():
  631. self.torch.musa.empty_cache()
  632. elif is_torch_xpu_available():
  633. self.torch.xpu.empty_cache()
  634. elif is_torch_npu_available():
  635. self.torch.npu.empty_cache()
  636. elif is_torch_hpu_available():
  637. # not available on hpu as it reserves all device memory for the current process
  638. # self.torch.npu.empty_cache()
  639. pass
  640. elif is_torch_mps_available():
  641. self.torch.mps.empty_cache()
  642. # concepts:
  643. # - alloc_delta: the difference of allocated memory between the end and the start
  644. # - peaked_delta: the difference between the peak memory and the current memory
  645. # in order to know how much memory the measured code consumed one needs to sum these two
  646. # gpu
  647. if self.torch is not None:
  648. if torch.cuda.is_available():
  649. self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
  650. self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
  651. elif is_torch_mlu_available():
  652. self.gpu_mem_used_now = self.torch.mlu.memory_allocated()
  653. self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated()
  654. elif is_torch_musa_available():
  655. self.gpu_mem_used_now = self.torch.musa.memory_allocated()
  656. self.gpu_mem_used_peak = self.torch.musa.max_memory_allocated()
  657. elif is_torch_xpu_available():
  658. self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
  659. self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
  660. elif is_torch_npu_available():
  661. self.gpu_mem_used_now = self.torch.npu.memory_allocated()
  662. self.gpu_mem_used_peak = self.torch.npu.max_memory_allocated()
  663. elif is_torch_hpu_available():
  664. self.gpu_mem_used_now = self.torch.hpu.memory_allocated()
  665. self.gpu_mem_used_peak = self.torch.hpu.max_memory_allocated()
  666. elif is_torch_mps_available():
  667. self.gpu_mem_used_now = self.torch.mps.current_allocated_memory()
  668. # self.torch.mps.max_memory_allocated() does not exist yet
  669. self.gpu_mem_used_peak = None
  670. else:
  671. raise ValueError("No available GPU device found!")
  672. self.gpu[self.cur_stage] = {
  673. "begin": self.gpu_mem_used_at_start,
  674. "end": self.gpu_mem_used_now,
  675. "alloc": (self.gpu_mem_used_now - self.gpu_mem_used_at_start),
  676. }
  677. if self.gpu_mem_used_peak is not None:
  678. self.gpu[self.cur_stage]["peaked"] = max(0, self.gpu_mem_used_peak - self.gpu_mem_used_now)
  679. else:
  680. self.gpu[self.cur_stage]["peaked"] = "Not available"
  681. # cpu
  682. self.cpu_mem_used_now = self.cpu_mem_used()
  683. self.cpu[self.cur_stage] = {
  684. "begin": self.cpu_mem_used_at_start,
  685. "end": self.cpu_mem_used_now,
  686. "alloc": (self.cpu_mem_used_now - self.cpu_mem_used_at_start),
  687. "peaked": max(0, self.cpu_mem_used_peak - self.cpu_mem_used_now),
  688. }
  689. # reset - cycle finished
  690. self.cur_stage = None
  691. def update_metrics(self, stage, metrics):
  692. """updates the metrics"""
  693. if self.skip_memory_metrics:
  694. return
  695. # deal with nested calls of eval during train - simply ignore those
  696. if self.cur_stage is not None and self.cur_stage != stage:
  697. return
  698. # since we don't have a way to return init metrics, we push them into the first of train/val/predict
  699. stages = [stage]
  700. if not self.init_reported:
  701. stages.insert(0, "init")
  702. self.init_reported = True
  703. for stage in stages:
  704. for t in ["alloc", "peaked"]:
  705. if stage in self.cpu and t in self.cpu[stage]:
  706. metrics[f"{stage}_mem_cpu_{t}_delta"] = self.cpu[stage][t]
  707. if self.torch is not None and stage in self.gpu and t in self.gpu[stage]:
  708. metrics[f"{stage}_mem_gpu_{t}_delta"] = self.gpu[stage][t]
  709. # if we need additional debug info, enable the following
  710. # for t in ["begin", "end"]:
  711. # if stage in self.cpu and t in self.cpu[stage]:
  712. # metrics[f"{stage}_mem_cpu_{t}"] = self.cpu[stage][t]
  713. # if self.torch is not None and stage in self.gpu and t in self.gpu[stage]:
  714. # metrics[f"{stage}_mem_gpu_{t}"] = self.gpu[stage][t]
  715. # since memory can be allocated before init, and it might be difficult to track overall
  716. # memory usage, in particular for GPU, let's report memory usage at the point init was called
  717. if stages[0] == "init":
  718. metrics["before_init_mem_cpu"] = self.cpu["init"]["begin"]
  719. if self.torch is not None:
  720. metrics["before_init_mem_gpu"] = self.gpu["init"]["begin"]
  721. # if we also wanted to report any additional memory allocations in between init and
  722. # whatever the next stage was we could also report this:
  723. # if self.cpu["init"]["end"] != self.cpu[stage]["begin"]:
  724. # metrics[f"after_init_mem_cpu_delta"] = self.cpu[stage]["begin"] - self.cpu["init"]["end"]
  725. # if self.torch is not None and self.gpu["init"]["end"] != self.gpu[stage]["begin"]:
  726. # metrics[f"after_init_mem_gpu_delta"] = self.gpu[stage]["begin"] - self.gpu["init"]["end"]
  727. def stop_and_update_metrics(self, metrics=None):
  728. """combine stop and metrics update in one call for simpler code"""
  729. if self.skip_memory_metrics:
  730. return
  731. stage = self.derive_stage()
  732. self.stop(stage)
  733. # init doesn't have metrics to update so we just save that data for later stages to retrieve
  734. if metrics is not None:
  735. self.update_metrics(stage, metrics)
  736. def has_length(dataset: Any) -> TypeGuard[Sized]:
  737. """
  738. Checks if the dataset implements __len__() and it doesn't raise an error
  739. """
  740. try:
  741. return len(dataset) is not None
  742. except TypeError:
  743. # TypeError: len() of unsized object
  744. return False
  745. except AttributeError:
  746. # Ray DataSets raises an AttributeError: https://github.com/ray-project/ray/blob/master/python/ray/data/dataset.py#L5616
  747. return False
  748. def denumpify_detensorize(metrics):
  749. """
  750. Recursively calls `.item()` on the element of the dictionary passed
  751. """
  752. if isinstance(metrics, (list, tuple)):
  753. return type(metrics)(denumpify_detensorize(m) for m in metrics)
  754. elif isinstance(metrics, dict):
  755. return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()})
  756. elif isinstance(metrics, np.generic):
  757. return metrics.item()
  758. elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1:
  759. return metrics.item()
  760. return metrics
  761. def number_of_arguments(func):
  762. """
  763. Return the number of arguments of the passed function, even if it's a partial function.
  764. """
  765. if isinstance(func, functools.partial):
  766. total_args = len(inspect.signature(func.func).parameters)
  767. return total_args - len(func.args) - len(func.keywords)
  768. return len(inspect.signature(func).parameters)
  769. def find_executable_batch_size(
  770. function: Callable | None = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
  771. ):
  772. """
  773. Args:
  774. A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
  775. CUDNN, the batch size is multiplied by 0.9 and passed to `function`. `function` must take in a `batch_size` parameter as
  776. its first argument.
  777. function (`Callable`, *optional*)
  778. A function to wrap
  779. starting_batch_size (`int`, *optional*)
  780. The batch size to try and fit into memory
  781. auto_find_batch_size (`bool`, *optional*)
  782. If False, will just execute `function`
  783. """
  784. if function is None:
  785. return functools.partial(
  786. find_executable_batch_size,
  787. starting_batch_size=starting_batch_size,
  788. auto_find_batch_size=auto_find_batch_size,
  789. )
  790. if auto_find_batch_size:
  791. requires_backends(find_executable_batch_size, "accelerate")
  792. from accelerate.utils import find_executable_batch_size as accelerate_find_executable_batch_size
  793. return accelerate_find_executable_batch_size(function=function, starting_batch_size=starting_batch_size)
  794. return functools.partial(function, batch_size=starting_batch_size)
class FSDPOption(ExplicitEnum):
    """
    String-valued options for FSDP (presumably PyTorch Fully Sharded Data Parallel) training.

    NOTE(review): these are presumably the values accepted by the `fsdp` training argument; confirm against
    `TrainingArguments`.
    """

    # presumably sharding strategies; TODO confirm against the FSDP integration
    FULL_SHARD = "full_shard"
    SHARD_GRAD_OP = "shard_grad_op"
    NO_SHARD = "no_shard"
    HYBRID_SHARD = "hybrid_shard"
    HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
    # presumably extra behaviours combined with a strategy; TODO confirm
    OFFLOAD = "offload"
    AUTO_WRAP = "auto_wrap"
  803. class RemoveColumnsCollator:
  804. """Wrap the data collator to remove unused columns before they are passed to the collator."""
  805. def __init__(
  806. self,
  807. data_collator,
  808. signature_columns,
  809. logger=None,
  810. model_name: str | None = None,
  811. description: str | None = None,
  812. ):
  813. self.data_collator = data_collator
  814. self.signature_columns = signature_columns
  815. self.logger = logger
  816. self.description = description
  817. self.model_name = model_name
  818. self.message_logged = False
  819. def _remove_columns(self, feature: dict) -> dict:
  820. if not isinstance(feature, dict):
  821. return feature
  822. if not self.message_logged and self.logger and self.model_name:
  823. ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
  824. if len(ignored_columns) > 0:
  825. dset_description = "" if self.description is None else f"in the {self.description} set"
  826. self.logger.info(
  827. f"The following columns {dset_description} don't have a corresponding argument in "
  828. f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
  829. f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
  830. " you can safely ignore this message."
  831. )
  832. self.message_logged = True
  833. return {k: v for k, v in feature.items() if k in self.signature_columns}
  834. def __call__(self, features: list[dict]):
  835. features = [self._remove_columns(feature) for feature in features]
  836. return self.data_collator(features)
  837. def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False):
  838. """A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules.
  839. Args:
  840. optim_target_modules (`Union[str, list[str]]`):
  841. A list of strings to try to match. Can be also a full string.
  842. key (`str`):
  843. A key to search any matches in optim_target_modules
  844. return_is_regex (`bool`):
  845. If set to `True`, the method will return whether the passed `optim_target_modules`
  846. is a regex or not.
  847. Returns:
  848. `bool` : True of match object if key matches any target modules from config, False or
  849. None if no match found
  850. `bool` : If the matched target module is a regex to silence out the warnings in Trainer
  851. for extra modules being found (only if `target_module_found=True` for an array of regex).
  852. """
  853. target_module_found = False
  854. is_regex = False
  855. if isinstance(optim_target_modules, str):
  856. target_module_found = bool(re.fullmatch(optim_target_modules, key))
  857. is_regex = optim_target_modules != key
  858. elif key in optim_target_modules: # from here, target_module_found must be a list of str
  859. # this module is specified directly in target_modules
  860. target_module_found = True
  861. elif any(target_key in key for target_key in optim_target_modules):
  862. target_module_found = True
  863. elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules):
  864. target_module_found = True
  865. is_regex = True
  866. if return_is_regex:
  867. return target_module_found, is_regex
  868. return target_module_found
  869. def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
  870. """
  871. This is the same as
  872. [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
  873. but for a sharded checkpoint.
  874. This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
  875. loaded in the model.
  876. Args:
  877. model (`torch.nn.Module`): The model in which to load the checkpoint.
  878. folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
  879. strict (`bool`, *optional*, defaults to `True`):
  880. Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
  881. prefer_safe (`bool`, *optional*, defaults to `True`):
  882. If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
  883. safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
  884. Returns:
  885. `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
  886. - `missing_keys` is a list of str containing the missing keys
  887. - `unexpected_keys` is a list of str containing the unexpected keys
  888. """
  889. # Load the index
  890. index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
  891. safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
  892. index_present = os.path.isfile(index_file)
  893. safe_index_present = os.path.isfile(safe_index_file)
  894. if not index_present and not safe_index_present:
  895. filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME)
  896. raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")
  897. load_safe = safe_index_present and (prefer_safe or not index_present)
  898. load_index = safe_index_file if load_safe else index_file
  899. with open(load_index, "r", encoding="utf-8") as f:
  900. index = json.load(f)
  901. shard_files = list(set(index["weight_map"].values()))
  902. # If strict=True, error before loading any of the state dicts.
  903. # TODO: Here, update the weight map with the config.dynamic_weight_conversion
  904. loaded_keys = index["weight_map"].keys()
  905. model_keys = model.state_dict().keys()
  906. missing_keys = [key for key in model_keys if key not in loaded_keys]
  907. unexpected_keys = [key for key in loaded_keys if key not in model_keys]
  908. if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
  909. error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
  910. if len(missing_keys) > 0:
  911. str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
  912. error_message += f"\nMissing key(s): {str_missing_keys}."
  913. if len(unexpected_keys) > 0:
  914. str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
  915. error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
  916. raise RuntimeError(error_message)
  917. if load_safe:
  918. loader = safe_load_file
  919. else:
  920. check_torch_load_is_safe()
  921. loader = partial(torch.load, map_location="cpu", weights_only=True)
  922. for shard_file in shard_files:
  923. state_dict = loader(os.path.join(folder, shard_file))
  924. model.load_state_dict(state_dict, strict=False)
  925. # Make sure memory is freed before we load the next state dict.
  926. del state_dict
  927. gc.collect()
  928. # Return the same thing as PyTorch load_state_dict function.
  929. return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
  930. def compare_trainer_and_checkpoint_args(training_args, trainer_state):
  931. """
  932. Compare training arguments with those stored in a checkpoint's trainer state.
  933. Logs a warning if there are mismatches between the current training arguments
  934. and the ones saved in the checkpoint.
  935. Args:
  936. training_args: The current training arguments.
  937. trainer_state: The trainer state loaded from a checkpoint.
  938. """
  939. attributes_map = {
  940. "logging_steps": "logging_steps",
  941. "eval_steps": "eval_steps",
  942. "save_steps": "save_steps",
  943. }
  944. has_warning = False
  945. warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: "
  946. for arg_attr, state_attr in attributes_map.items():
  947. arg_value = getattr(training_args, arg_attr, None)
  948. state_value = getattr(trainer_state, state_attr, None)
  949. if arg_value is not None and state_value is not None and arg_value != state_value:
  950. warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)"
  951. has_warning = True
  952. # train bs is special as we need to account for multi-GPU
  953. train_bs_args = training_args.per_device_train_batch_size
  954. train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu)
  955. if train_bs_args != train_bs_state:
  956. warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)"
  957. has_warning = True
  958. if has_warning:
  959. logger.warning_once(warning_str)
  960. def align_special_tokens(model, processing_class):
  961. """
  962. Aligns the special tokens of the tokenizer with the model configs.
  963. A new tokens may be defined in the tokenizer for fine-tuning purposes, e.g. an "end of turn" token may be
  964. added on chat models. In that case, we want the model configs to be aligned with the tokenizer, so that all
  965. downstream uses work as expected. This alignment should happen before training, to ensure the prediction step
  966. uses the new tokens as well.
  967. """
  968. from .processing_utils import ProcessorMixin
  969. from .tokenization_utils_base import PreTrainedTokenizerBase
  970. if isinstance(processing_class, ProcessorMixin):
  971. tokenizer: PreTrainedTokenizerBase = processing_class.tokenizer
  972. else:
  973. tokenizer = processing_class
  974. model_has_generation_config = hasattr(model, "generation_config") and model.generation_config is not None
  975. updated_tokens = {}
  976. # 1 - Align EOS token. EOS is more complex than the others, as `generation_config` may hold more than one EOS
  977. # token.
  978. tokenizer_has_new_eos = tokenizer.eos_token_id != getattr(model.config, "eos_token_id", None)
  979. if model_has_generation_config:
  980. # `generation_config.eos_token_id` is None: direct comparison
  981. if model.generation_config.eos_token_id is None:
  982. tokenizer_has_new_eos |= tokenizer.eos_token_id != model.generation_config.eos_token_id
  983. else:
  984. # `generation_config.eos_token_id` is an `int`: convert it to list (and continue below)
  985. if isinstance(model.generation_config.eos_token_id, int):
  986. model.generation_config.eos_token_id = [model.generation_config.eos_token_id]
  987. # `generation_config.eos_token_id` is a `list`: check if the tokenizer's EOS token is in the list
  988. tokenizer_has_new_eos |= tokenizer.eos_token_id not in model.generation_config.eos_token_id
  989. if tokenizer_has_new_eos:
  990. updated_tokens["eos_token_id"] = tokenizer.eos_token_id
  991. model.config.eos_token_id = tokenizer.eos_token_id
  992. # The generation config may hold more than one EOS token. We preserve the original EOS tokens: any of the
  993. # EOS tokens defined here will halt generation.
  994. if model_has_generation_config:
  995. all_eos_tokens = [tokenizer.eos_token_id]
  996. if model.generation_config.eos_token_id is not None:
  997. all_eos_tokens += list(model.generation_config.eos_token_id)
  998. model.generation_config.eos_token_id = [token for token in all_eos_tokens if token is not None]
  999. # 2 - Align BOS
  1000. tokenizer_has_new_bos = tokenizer.bos_token_id != getattr(model.config, "bos_token_id", None)
  1001. if model_has_generation_config:
  1002. tokenizer_has_new_bos |= tokenizer.bos_token_id != model.generation_config.bos_token_id
  1003. if tokenizer_has_new_bos:
  1004. updated_tokens["bos_token_id"] = tokenizer.bos_token_id
  1005. model.config.bos_token_id = tokenizer.bos_token_id
  1006. if model_has_generation_config:
  1007. model.generation_config.bos_token_id = tokenizer.bos_token_id
  1008. # 3 - Align PAD
  1009. tokenizer_has_new_pad = tokenizer.pad_token_id != getattr(model.config, "pad_token_id", None)
  1010. if model_has_generation_config:
  1011. tokenizer_has_new_pad |= tokenizer.pad_token_id != model.generation_config.pad_token_id
  1012. if tokenizer_has_new_pad:
  1013. updated_tokens["pad_token_id"] = tokenizer.pad_token_id
  1014. model.config.pad_token_id = tokenizer.pad_token_id
  1015. if model_has_generation_config:
  1016. model.generation_config.pad_token_id = tokenizer.pad_token_id
  1017. # 4 - Warn users about the changes
  1018. if len(updated_tokens) > 0:
  1019. logger.warning(
  1020. "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. "
  1021. "The model config and generation config were aligned accordingly, being updated with the tokenizer's "
  1022. f"values. Updated tokens: {updated_tokens}."
  1023. )
  1024. @contextlib.contextmanager
  1025. def suppress_progress_bars():
  1026. """Context manager that suppresses huggingface_hub progress bars."""
  1027. import huggingface_hub.utils as hf_hub_utils
  1028. hf_hub_utils.disable_progress_bars()
  1029. try:
  1030. yield
  1031. finally:
  1032. hf_hub_utils.enable_progress_bars()