trainer_pt_utils.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336
  1. # Copyright 2020-present the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Torch utilities for the Trainer class.
  16. """
  17. import contextlib
  18. import copy
  19. import datetime
  20. import io
  21. import json
  22. import math
  23. import os
  24. import re
  25. import sys
  26. import warnings
  27. from collections.abc import Iterator, Mapping
  28. from contextlib import contextmanager
  29. from dataclasses import dataclass, field
  30. from itertools import chain
  31. from logging import StreamHandler
  32. from typing import Any
  33. import numpy as np
  34. import torch
  35. import torch.distributed as dist
  36. from packaging import version
  37. from torch import nn
  38. from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
  39. from torch.utils.data.distributed import DistributedSampler
  40. from .integrations.deepspeed import is_deepspeed_zero3_enabled
  41. from .tokenization_utils_base import BatchEncoding
  42. from .utils import (
  43. is_sagemaker_mp_enabled,
  44. is_torch_available,
  45. is_torch_xla_available,
  46. is_training_run_on_sagemaker,
  47. logging,
  48. )
# When running a SageMaker training job, add a handler that writes log records to stdout.
if is_training_run_on_sagemaker():
    logging.add_handler(StreamHandler(sys.stdout))

# `xr` is only bound when torch_xla is importable; code referencing it (e.g. `get_tpu_sampler`) relies on this guard.
if is_torch_xla_available():
    import torch_xla.runtime as xr

if is_torch_available():
    from torch.optim.lr_scheduler import LRScheduler

# Module-level logger used by the samplers below.
logger = logging.get_logger(__name__)
  56. def get_dataloader_sampler(dataloader):
  57. if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
  58. return get_dataloader_sampler(dataloader.batch_sampler)
  59. elif hasattr(dataloader, "sampler"):
  60. return dataloader.sampler
  61. def atleast_1d(tensor_or_array: torch.Tensor | np.ndarray):
  62. if isinstance(tensor_or_array, torch.Tensor):
  63. tensor_or_array = torch.atleast_1d(tensor_or_array)
  64. else:
  65. tensor_or_array = np.atleast_1d(tensor_or_array)
  66. return tensor_or_array
  67. def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
  68. """Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
  69. tensor1 = atleast_1d(tensor1)
  70. tensor2 = atleast_1d(tensor2)
  71. if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
  72. return torch.cat((tensor1, tensor2), dim=0)
  73. # Let's figure out the new shape
  74. new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:]
  75. # Now let's fill the result tensor
  76. result = tensor1.new_full(new_shape, padding_index)
  77. result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1
  78. result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2
  79. return result
  80. def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
  81. """Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
  82. array1 = atleast_1d(array1)
  83. array2 = atleast_1d(array2)
  84. if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
  85. return np.concatenate((array1, array2), axis=0)
  86. # Let's figure out the new shape
  87. new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:]
  88. # Now let's fill the result tensor
  89. result = np.full_like(array1, padding_index, shape=new_shape)
  90. result[: array1.shape[0], : array1.shape[1]] = array1
  91. result[array1.shape[0] :, : array2.shape[1]] = array2
  92. return result
  93. def nested_concat(tensors, new_tensors, padding_index=-100):
  94. """
  95. Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
  96. nested list/tuples/dict of tensors.
  97. """
  98. if not (isinstance(tensors, torch.Tensor) and isinstance(new_tensors, torch.Tensor)):
  99. assert type(tensors) is type(new_tensors), (
  100. f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
  101. )
  102. if isinstance(tensors, (list, tuple)):
  103. return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
  104. elif isinstance(tensors, torch.Tensor):
  105. return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
  106. elif isinstance(tensors, Mapping):
  107. return type(tensors)(
  108. {k: nested_concat(t, new_tensors[k], padding_index=padding_index) for k, t in tensors.items()}
  109. )
  110. elif isinstance(tensors, np.ndarray):
  111. return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
  112. else:
  113. raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")
  114. def find_batch_size(tensors):
  115. """
  116. Find the first dimension of a tensor in a nested list/tuple/dict of tensors.
  117. """
  118. if isinstance(tensors, (list, tuple)):
  119. for t in tensors:
  120. result = find_batch_size(t)
  121. if result is not None:
  122. return result
  123. elif isinstance(tensors, Mapping):
  124. for value in tensors.values():
  125. result = find_batch_size(value)
  126. if result is not None:
  127. return result
  128. elif isinstance(tensors, (torch.Tensor, np.ndarray)):
  129. return tensors.shape[0] if len(tensors.shape) >= 1 else None
  130. def nested_numpify(tensors):
  131. "Numpify `tensors` (even if it's a nested list/tuple/dict of tensors)."
  132. if isinstance(tensors, (list, tuple)):
  133. return type(tensors)(nested_numpify(t) for t in tensors)
  134. if isinstance(tensors, Mapping):
  135. return type(tensors)({k: nested_numpify(t) for k, t in tensors.items()})
  136. t = tensors.cpu()
  137. if t.dtype == torch.bfloat16:
  138. # As of Numpy 1.21.4, NumPy does not support bfloat16 (see
  139. # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
  140. # Until Numpy adds bfloat16, we must convert float32.
  141. t = t.to(torch.float32)
  142. return t.numpy()
  143. def nested_detach(tensors):
  144. "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
  145. if isinstance(tensors, (list, tuple)):
  146. return type(tensors)(nested_detach(t) for t in tensors)
  147. elif isinstance(tensors, Mapping):
  148. return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
  149. return tensors.detach() if isinstance(tensors, torch.Tensor) else tensors
  150. def nested_xla_mesh_reduce(tensors, name):
  151. if is_torch_xla_available():
  152. import torch_xla.core.xla_model as xm
  153. if isinstance(tensors, (list, tuple)):
  154. return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
  155. if isinstance(tensors, Mapping):
  156. return type(tensors)(
  157. {k: nested_xla_mesh_reduce(t, f"{name}_{i}") for i, (k, t) in enumerate(tensors.items())}
  158. )
  159. tensors = atleast_1d(tensors)
  160. return xm.mesh_reduce(name, tensors, torch.cat)
  161. else:
  162. raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
  163. def distributed_concat(tensor: Any, num_total_examples: int | None = None) -> Any:
  164. try:
  165. if isinstance(tensor, (tuple, list)):
  166. return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
  167. if isinstance(tensor, Mapping):
  168. return type(tensor)({k: distributed_concat(t, num_total_examples) for k, t in tensor.items()})
  169. tensor = atleast_1d(tensor).contiguous()
  170. output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
  171. dist.all_gather(output_tensors, tensor)
  172. concat = torch.cat(output_tensors, dim=0)
  173. # truncate the dummy elements added by SequentialDistributedSampler
  174. if num_total_examples is not None:
  175. concat = concat[:num_total_examples]
  176. return concat
  177. except AssertionError:
  178. raise AssertionError("Not currently using distributed training")
  179. def nested_gather(tensors, parallel_mode, name=None):
  180. """
  181. Gather value of `tensors` (tensor or list/tuple of nested tensors) across processes.
  182. """
  183. from .training_args import ParallelMode
  184. if tensors is None:
  185. return
  186. if is_torch_xla_available():
  187. if name is None:
  188. name = "nested_gather"
  189. tensors = nested_xla_mesh_reduce(tensors, name)
  190. elif is_sagemaker_mp_enabled():
  191. tensors = smp_gather(tensors)
  192. elif parallel_mode == ParallelMode.DISTRIBUTED:
  193. tensors = distributed_concat(tensors)
  194. return tensors
  195. def is_attention_mask_causal(attention_mask):
  196. """
  197. Check if an attention mask is causal (compatible with causal attention).
  198. Context parallelism only supports causal attention patterns. This function
  199. checks if the provided attention mask is compatible.
  200. Args:
  201. attention_mask (`torch.Tensor`): The attention mask to check.
  202. Returns:
  203. `bool`: True if the mask is causal or compatible with causal attention.
  204. """
  205. if attention_mask is None:
  206. return True # No mask is considered causal (model uses default causal masking)
  207. # Handle different mask dimensions
  208. if attention_mask.dim() == 2:
  209. # (batch_size, seq_len) - standard padding mask, compatible with causal attention
  210. return True
  211. elif attention_mask.dim() in [3, 4]:
  212. # (batch_size, seq_len, seq_len) or (batch_size, num_heads, seq_len, seq_len)
  213. # Check if it's lower triangular (causal)
  214. seq_len = attention_mask.shape[-1]
  215. if seq_len <= 1:
  216. return True # Single token or empty is always causal
  217. # Take first batch and head (if 4D) for checking pattern
  218. if attention_mask.dim() == 4:
  219. mask = attention_mask[0, 0] # First batch, first head
  220. else:
  221. mask = attention_mask[0] # First batch
  222. # Check if upper triangular part is masked (should be 0 or very negative for causal)
  223. upper_triangular = torch.triu(mask, diagonal=1)
  224. # For causal masks, upper triangular should be 0 or very negative (like -inf)
  225. # Use a reasonable threshold to handle float precision issues
  226. is_causal = torch.all(upper_triangular <= 1e-6) or torch.all(upper_triangular < -1e4)
  227. return is_causal.item() if isinstance(is_causal, torch.Tensor) else is_causal
  228. # For unknown dimensions, be conservative and reject
  229. return False
  230. def distributed_broadcast_scalars(
  231. scalars: list[int | float],
  232. num_total_examples: int | None = None,
  233. device: torch.device | None = torch.device("cuda"),
  234. ) -> torch.Tensor:
  235. try:
  236. tensorized_scalar = torch.tensor(scalars, device=device)
  237. output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
  238. dist.all_gather(output_tensors, tensorized_scalar)
  239. concat = torch.cat(output_tensors, dim=0)
  240. # truncate the dummy elements added by SequentialDistributedSampler
  241. if num_total_examples is not None:
  242. concat = concat[:num_total_examples]
  243. return concat
  244. except AssertionError:
  245. raise AssertionError("Not currently using distributed training")
  246. def reissue_pt_warnings(caught_warnings):
  247. # Reissue warnings
  248. if len(caught_warnings) > 1:
  249. for w in caught_warnings:
  250. if w.category is not UserWarning:
  251. warnings.warn(w.message, w.category)
  252. @contextmanager
  253. def torch_distributed_zero_first(local_rank: int):
  254. """
  255. Decorator to make all processes in distributed training wait for each local_master to do something.
  256. Args:
  257. local_rank (`int`): The rank of the local process.
  258. """
  259. if local_rank not in [-1, 0]:
  260. dist.barrier()
  261. yield
  262. if local_rank == 0:
  263. dist.barrier()
  264. class DistributedSamplerWithLoop(DistributedSampler):
  265. """
  266. Like a torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the shuffled
  267. samples to make each process have a round multiple of batch_size samples.
  268. Args:
  269. dataset (`torch.utils.data.Dataset`):
  270. Dataset used for sampling.
  271. batch_size (`int`):
  272. The batch size used with this sampler
  273. kwargs (`dict[str, Any]`, *optional*):
  274. All other keyword arguments passed to `DistributedSampler`.
  275. """
  276. def __init__(self, dataset, batch_size, **kwargs):
  277. super().__init__(dataset, **kwargs)
  278. self.batch_size = batch_size
  279. def __iter__(self):
  280. indices = list(super().__iter__())
  281. remainder = 0 if len(indices) % self.batch_size == 0 else self.batch_size - len(indices) % self.batch_size
  282. # DistributedSampler already added samples from the beginning to make the number of samples a round multiple
  283. # of the world size, so we skip those.
  284. start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0
  285. indices += indices[start_remainder : start_remainder + remainder]
  286. return iter(indices)
  287. class EvalLoopContainer:
  288. """
  289. Container to store intermediate results of evaluation loop.
  290. Args:
  291. do_nested_concat (`bool`, *optional*, defaults to `True`):
  292. If set to `True`, each iteration will recursively concatenate a new object containing tensors to
  293. the existing stored tensors, provided that the structure of the existing object and the new one
  294. are identical. If set to `False`, all newly added tensors will be stored in a list.
  295. padding_index (`int`, *optional*, defaults to -100):
  296. Value used to pad tensors of different shapes when `do_nested_concat=True`.
  297. """
  298. def __init__(self, do_nested_concat: bool = True, padding_index: int = -100):
  299. self.do_nested_concat = do_nested_concat
  300. self.padding_index = padding_index
  301. self.tensors = None
  302. self.arrays = None
  303. def add(self, tensors) -> None:
  304. """Add tensors to the stored objects. If `do_nested_concat=True`, the tensors will be concatenated recursively."""
  305. if self.tensors is None:
  306. self.tensors = tensors if self.do_nested_concat else [tensors]
  307. elif self.do_nested_concat:
  308. self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index)
  309. else:
  310. self.tensors.append(tensors)
  311. def to_cpu_and_numpy(self) -> None:
  312. """Move tensors in stored objects to CPU and convert them to numpy arrays."""
  313. # Check if we have something to add, if not just return
  314. if self.tensors is None:
  315. return
  316. new_arrays = nested_numpify(self.tensors)
  317. if self.arrays is None:
  318. self.arrays = new_arrays
  319. elif self.do_nested_concat:
  320. self.arrays = nested_concat(self.arrays, new_arrays, padding_index=self.padding_index)
  321. else:
  322. self.arrays.extend(new_arrays)
  323. # reset device tensors after adding to cpu
  324. self.tensors = None
  325. def get_arrays(self):
  326. """Returns the numpified and moved to CPU stored objects."""
  327. self.to_cpu_and_numpy()
  328. return self.arrays
  329. def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
  330. if xr.world_size() <= 1:
  331. return RandomSampler(dataset)
  332. return DistributedSampler(dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal())
  333. def nested_new_like(arrays, num_samples, padding_index=-100):
  334. """Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
  335. if isinstance(arrays, (list, tuple)):
  336. return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
  337. return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))
  338. def expand_like(arrays, new_seq_length, padding_index=-100):
  339. """Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding."""
  340. result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:])
  341. result[:, : arrays.shape[1]] = arrays
  342. return result
  343. def nested_truncate(tensors, limit):
  344. "Truncate `tensors` at `limit` (even if it's a nested list/tuple/dict of tensors)."
  345. if isinstance(tensors, (list, tuple)):
  346. return type(tensors)(nested_truncate(t, limit) for t in tensors)
  347. if isinstance(tensors, Mapping):
  348. return type(tensors)({k: nested_truncate(t, limit) for k, t in tensors.items()})
  349. return tensors[:limit]
  350. @dataclass
  351. class LabelSmoother:
  352. """
  353. Adds label-smoothing on a pre-computed output from a Transformers model.
  354. Args:
  355. epsilon (`float`, *optional*, defaults to 0.1):
  356. The label smoothing factor.
  357. ignore_index (`int`, *optional*, defaults to -100):
  358. The index in the labels to ignore when computing the loss.
  359. """
  360. epsilon: float = 0.1
  361. ignore_index: int = -100
  362. def __call__(self, model_output, labels, shift_labels=False):
  363. logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
  364. if shift_labels:
  365. logits = logits[..., :-1, :].contiguous()
  366. labels = labels[..., 1:].contiguous()
  367. log_probs = -nn.functional.log_softmax(logits, dim=-1)
  368. if labels.dim() == log_probs.dim() - 1:
  369. labels = labels.unsqueeze(-1)
  370. padding_mask = labels.eq(self.ignore_index)
  371. # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
  372. # will ignore them in any case.
  373. labels = torch.clamp(labels, min=0)
  374. nll_loss = log_probs.gather(dim=-1, index=labels)
  375. # works for fp16 input tensor too, by internally upcasting it to fp32
  376. smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
  377. nll_loss.masked_fill_(padding_mask, 0.0)
  378. smoothed_loss.masked_fill_(padding_mask, 0.0)
  379. # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
  380. num_active_elements = padding_mask.numel() - padding_mask.long().sum()
  381. nll_loss = nll_loss.sum() / num_active_elements
  382. smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
  383. return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss
  384. def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
  385. """
  386. Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
  387. lengths. To do this, the indices are:
  388. - randomly permuted
  389. - grouped in mega-batches of size `mega_batch_mult * batch_size`
  390. - sorted by length in each mega-batch
  391. The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
  392. maximum length placed first, so that an OOM happens sooner rather than later.
  393. """
  394. # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
  395. if mega_batch_mult is None:
  396. mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
  397. # Just in case, for tiny datasets
  398. if mega_batch_mult == 0:
  399. mega_batch_mult = 1
  400. # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
  401. indices = torch.randperm(len(lengths), generator=generator)
  402. megabatch_size = mega_batch_mult * batch_size
  403. megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
  404. megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
  405. # The rest is to get the biggest batch first.
  406. # Since each megabatch is sorted by descending length, the longest element is the first
  407. megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
  408. max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
  409. # Switch to put the longest element in first position
  410. megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
  411. return [i for megabatch in megabatches for i in megabatch]
  412. class LengthGroupedSampler(Sampler):
  413. r"""
  414. Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
  415. keeping a bit of randomness.
  416. """
  417. def __init__(
  418. self,
  419. batch_size: int,
  420. dataset: Dataset | None = None,
  421. lengths: list[int] | None = None,
  422. model_input_name: str | None = None,
  423. generator=None,
  424. ):
  425. if dataset is None and lengths is None:
  426. raise ValueError("One of dataset and lengths must be provided.")
  427. self.batch_size = batch_size
  428. if lengths is None:
  429. model_input_name = model_input_name if model_input_name is not None else "input_ids"
  430. if not isinstance(dataset[0], (dict, BatchEncoding)) or model_input_name not in dataset[0]:
  431. raise ValueError(
  432. "Can only automatically infer lengths for datasets whose items are dictionaries with an "
  433. f"'{model_input_name}' key."
  434. )
  435. lengths = [len(feature[model_input_name]) for feature in dataset]
  436. elif isinstance(lengths, torch.Tensor):
  437. logger.info(
  438. "If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to list[int]..."
  439. )
  440. lengths = lengths.tolist()
  441. self.lengths = lengths
  442. self.generator = generator
  443. def __len__(self):
  444. return len(self.lengths)
  445. def __iter__(self):
  446. indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
  447. return iter(indices)
  448. class DistributedLengthGroupedSampler(DistributedSampler):
  449. r"""
  450. Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
  451. length while keeping a bit of randomness.
  452. """
  453. # Copied and adapted from PyTorch DistributedSampler.
  454. def __init__(
  455. self,
  456. batch_size: int,
  457. dataset: Dataset | None = None,
  458. num_replicas: int | None = None,
  459. rank: int | None = None,
  460. seed: int = 0,
  461. drop_last: bool = False,
  462. lengths: list[int] | None = None,
  463. model_input_name: str | None = None,
  464. ):
  465. if dataset is None and lengths is None:
  466. raise ValueError("One of dataset and lengths must be provided.")
  467. if num_replicas is None:
  468. if not dist.is_available():
  469. raise RuntimeError("Requires distributed package to be available")
  470. num_replicas = dist.get_world_size()
  471. if rank is None:
  472. if not dist.is_available():
  473. raise RuntimeError("Requires distributed package to be available")
  474. rank = dist.get_rank()
  475. self.batch_size = batch_size
  476. self.num_replicas = num_replicas
  477. self.rank = rank
  478. self.epoch = 0
  479. self.drop_last = drop_last
  480. if lengths is None:
  481. model_input_name = model_input_name if model_input_name is not None else "input_ids"
  482. if not isinstance(dataset[0], (dict, BatchEncoding)) or model_input_name not in dataset[0]:
  483. raise ValueError(
  484. "Can only automatically infer lengths for datasets whose items are dictionaries with an "
  485. f"'{model_input_name}' key."
  486. )
  487. lengths = [len(feature[model_input_name]) for feature in dataset]
  488. elif isinstance(lengths, torch.Tensor):
  489. logger.info(
  490. "If lengths is a torch.Tensor, DistributedLengthGroupedSampler will be slow. Converting lengths to"
  491. " list[int]..."
  492. )
  493. lengths = lengths.tolist()
  494. self.lengths = lengths
  495. # If the dataset length is evenly divisible by # of replicas, then there
  496. # is no need to drop any data, since the dataset will be split equally.
  497. if self.drop_last and len(self.lengths) % self.num_replicas != 0:
  498. # Split to nearest available length that is evenly divisible.
  499. # This is to ensure each rank receives the same amount of data when
  500. # using this Sampler.
  501. self.num_samples = math.ceil((len(self.lengths) - self.num_replicas) / self.num_replicas)
  502. else:
  503. self.num_samples = math.ceil(len(self.lengths) / self.num_replicas)
  504. self.total_size = self.num_samples * self.num_replicas
  505. self.seed = seed
  506. def __iter__(self) -> Iterator:
  507. # Deterministically shuffle based on epoch and seed
  508. g = torch.Generator()
  509. g.manual_seed(self.seed + self.epoch)
  510. indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
  511. if not self.drop_last:
  512. # add extra samples to make it evenly divisible
  513. indices += indices[: (self.total_size - len(indices))]
  514. else:
  515. # remove tail of data to make it evenly divisible
  516. indices = indices[: self.total_size]
  517. assert len(indices) == self.total_size
  518. # subsample
  519. indices = indices[self.rank : self.total_size : self.num_replicas]
  520. assert len(indices) == self.num_samples
  521. return iter(indices)
  522. class ShardSampler(Sampler):
  523. """
  524. Sampler that shards batches between several processes. Dispatches indices batch by batch: on 2 processes with batch
  525. size 4, the first two batches are `[0, 1, 2, 3, 4, 5, 6, 7]` and `[8, 9, 10, 11, 12, 13, 14, 15]`, which shard into
  526. `[0, 1, 2, 3]` and `[8, 9, 10, 11]` for GPU-0 and `[4, 5, 6, 7]` and `[12, 13, 14, 15]` for GPU-1.
  527. The sampler thus yields `[0, 1, 2, 3, 8, 9, 10, 11]` on GPU-0 and `[4, 5, 6, 7, 12, 13, 14, 15]` on GPU-1.
  528. """
  529. def __init__(
  530. self,
  531. dataset: Dataset,
  532. batch_size: int = 1,
  533. drop_last: bool = False,
  534. num_processes: int = 1,
  535. process_index: int = 0,
  536. ):
  537. self.dataset = dataset
  538. self.batch_size = batch_size
  539. self.drop_last = drop_last
  540. self.num_processes = num_processes
  541. self.process_index = process_index
  542. self.total_batch_size = total_batch_size = batch_size * num_processes
  543. num_batches = len(dataset) // total_batch_size if drop_last else math.ceil(len(dataset) / total_batch_size)
  544. self.total_num_samples = num_batches * total_batch_size
  545. def __iter__(self):
  546. indices = list(range(len(self.dataset)))
  547. # Add extra samples to make it evenly divisible. While loop is there in the edge case we have a tiny dataset
  548. # and it needs to be done several times.
  549. while len(indices) < self.total_num_samples:
  550. indices += indices[: (self.total_num_samples - len(indices))]
  551. result = []
  552. for batch_start in range(self.batch_size * self.process_index, self.total_num_samples, self.total_batch_size):
  553. result += indices[batch_start : batch_start + self.batch_size]
  554. return iter(result)
  555. def __len__(self):
  556. # Each shard only sees a fraction of total_num_samples.
  557. return self.total_num_samples // self.num_processes
  558. class IterableDatasetShard(IterableDataset):
  559. """
  560. Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
  561. always yield a number of samples that is a round multiple of the actual batch size (which is `batch_size x
  562. num_processes`). Depending on the value of the `drop_last` attribute, it will either stop the iteration at the
  563. first batch that would be too small or loop with indices from the beginning.
  564. On two processes with an iterable dataset yielding of `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]` with a batch size of
  565. 2:
  566. - the shard on process 0 will yield `[0, 1, 4, 5, 8, 9]` so will see batches `[0, 1]`, `[4, 5]`, `[8, 9]`
  567. - the shard on process 1 will yield `[2, 3, 6, 7, 10, 11]` so will see batches `[2, 3]`, `[6, 7]`, `[10, 11]`
  568. <Tip warning={true}>
  569. If your IterableDataset implements some randomization that needs to be applied the same way on all processes
  570. (for instance, a shuffling), you should use a `torch.Generator` in a `generator` attribute of the `dataset` to
  571. generate your random numbers and call the [`~trainer_pt_utils.IterableDatasetShard.set_epoch`] method of this
  572. object. It will set the seed of this `generator` to `seed + epoch` on all processes before starting the
  573. iteration. Alternatively, you can also implement a `set_epoch()` method in your iterable dataset to deal with
  574. this.
  575. </Tip>
  576. Args:
  577. dataset (`torch.utils.data.IterableDataset`):
  578. The batch sampler to split in several shards.
  579. batch_size (`int`, *optional*, defaults to 1):
  580. The size of the batches per shard.
  581. drop_last (`bool`, *optional*, defaults to `False`):
  582. Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
  583. beginning.
  584. num_processes (`int`, *optional*, defaults to 1):
  585. The number of processes running concurrently.
  586. process_index (`int`, *optional*, defaults to 0):
  587. The index of the current process.
  588. seed (`int`, *optional*, defaults to 0):
  589. A random seed that will be used for the random number generation in
  590. [`~trainer_pt_utils.IterableDatasetShard.set_epoch`].
  591. """
  592. def __init__(
  593. self,
  594. dataset: IterableDataset,
  595. batch_size: int = 1,
  596. drop_last: bool = False,
  597. num_processes: int = 1,
  598. process_index: int = 0,
  599. seed: int = 0,
  600. ):
  601. self.dataset = dataset
  602. self.batch_size = batch_size
  603. self.drop_last = drop_last
  604. self.num_processes = num_processes
  605. self.process_index = process_index
  606. self.seed = seed
  607. self.epoch = 0
  608. self.num_examples = 0
  609. def set_epoch(self, epoch):
  610. self.epoch = epoch
  611. if hasattr(self.dataset, "set_epoch"):
  612. self.dataset.set_epoch(epoch)
  613. def __iter__(self):
  614. self.num_examples = 0
  615. if (
  616. not hasattr(self.dataset, "set_epoch")
  617. and hasattr(self.dataset, "generator")
  618. and isinstance(self.dataset.generator, torch.Generator)
  619. ):
  620. self.dataset.generator.manual_seed(self.seed + self.epoch)
  621. real_batch_size = self.batch_size * self.num_processes
  622. process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)
  623. first_batch = None
  624. current_batch = []
  625. for element in self.dataset:
  626. self.num_examples += 1
  627. current_batch.append(element)
  628. # Wait to have a full batch before yielding elements.
  629. if len(current_batch) == real_batch_size:
  630. for i in process_slice:
  631. yield current_batch[i]
  632. if first_batch is None:
  633. first_batch = current_batch.copy()
  634. current_batch = []
  635. # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
  636. if not self.drop_last and len(current_batch) > 0:
  637. if first_batch is None:
  638. first_batch = current_batch.copy()
  639. while len(current_batch) < real_batch_size:
  640. current_batch += first_batch
  641. for i in process_slice:
  642. yield current_batch[i]
  643. def __len__(self):
  644. # Will raise an error if the underlying dataset is not sized.
  645. if self.drop_last:
  646. return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
  647. else:
  648. return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
  649. def _secs2timedelta(secs):
  650. """
  651. Convert seconds to hh:mm:ss.msec, msecs rounded to 2 decimal places.
  652. """
  653. msec = int(abs(secs - int(secs)) * 100)
  654. return f"{datetime.timedelta(seconds=int(secs))}.{msec:02d}"
  655. def metrics_format(metrics: dict[str, float]) -> dict[str, float]:
  656. """
  657. Reformat Trainer metrics values to a human-readable format.
  658. Args:
  659. metrics (`dict[str, float]`):
  660. The metrics returned from train/evaluate/predict
  661. Returns:
  662. metrics (`dict[str, float]`): The reformatted metrics
  663. """
  664. metrics_copy = metrics.copy()
  665. for k, v in metrics_copy.items():
  666. if "_mem_" in k:
  667. metrics_copy[k] = f"{v >> 20}MB"
  668. elif "_runtime" in k:
  669. metrics_copy[k] = _secs2timedelta(v)
  670. elif k == "total_flos":
  671. metrics_copy[k] = f"{int(v) >> 30}GF"
  672. elif isinstance(metrics_copy[k], float):
  673. metrics_copy[k] = round(v, 4)
  674. return metrics_copy
  675. # Trainer helper method: imported into the Trainer class and used as a method (takes `self` as first argument).
  676. def log_metrics(self, split, metrics):
  677. """
  678. Log metrics in a specially formatted way.
  679. Under distributed environment this is done only for a process with rank 0.
  680. Args:
  681. split (`str`):
  682. Mode/split name: one of `train`, `eval`, `test`
  683. metrics (`dict[str, float]`):
  684. The metrics returned from train/evaluate/predictmetrics: metrics dict
  685. Notes on memory reports:
  686. In order to get memory usage report you need to install `psutil`. You can do that with `pip install psutil`.
  687. Now when this method is run, you will see a report that will include:
  688. ```
  689. init_mem_cpu_alloc_delta = 1301MB
  690. init_mem_cpu_peaked_delta = 154MB
  691. init_mem_gpu_alloc_delta = 230MB
  692. init_mem_gpu_peaked_delta = 0MB
  693. train_mem_cpu_alloc_delta = 1345MB
  694. train_mem_cpu_peaked_delta = 0MB
  695. train_mem_gpu_alloc_delta = 693MB
  696. train_mem_gpu_peaked_delta = 7MB
  697. ```
  698. **Understanding the reports:**
  699. - the first segment, e.g., `train__`, tells you which stage the metrics are for. Reports starting with `init_`
  700. will be added to the first stage that gets run. So that if only evaluation is run, the memory usage for the
  701. `__init__` will be reported along with the `eval_` metrics.
  702. - the third segment, is either `cpu` or `gpu`, tells you whether it's the general RAM or the gpu0 memory
  703. metric.
  704. - `*_alloc_delta` - is the difference in the used/allocated memory counter between the end and the start of the
  705. stage - it can be negative if a function released more memory than it allocated.
  706. - `*_peaked_delta` - is any extra memory that was consumed and then freed - relative to the current allocated
  707. memory counter - it is never negative. When you look at the metrics of any stage you add up `alloc_delta` +
  708. `peaked_delta` and you know how much memory was needed to complete that stage.
  709. The reporting happens only for process of rank 0 and gpu 0 (if there is a gpu). Typically this is enough since the
  710. main process does the bulk of work, but it could be not quite so if model parallel is used and then other GPUs may
  711. use a different amount of gpu memory. This is also not the same under DataParallel where gpu0 may require much more
  712. memory than the rest since it stores the gradient and optimizer states for all participating GPUs. Perhaps in the
  713. future these reports will evolve to measure those too.
  714. The CPU RAM metric measures RSS (Resident Set Size) includes both the memory which is unique to the process and the
  715. memory shared with other processes. It is important to note that it does not include swapped out memory, so the
  716. reports could be imprecise.
  717. The CPU peak memory is measured using a sampling thread. Due to python's GIL it may miss some of the peak memory if
  718. that thread didn't get a chance to run when the highest memory was used. Therefore this report can be less than
  719. reality. Using `tracemalloc` would have reported the exact peak memory, but it doesn't report memory allocations
  720. outside of python. So if some C++ CUDA extension allocated its own memory it won't be reported. And therefore it
  721. was dropped in favor of the memory sampling approach, which reads the current process memory usage.
  722. The GPU allocated and peak memory reporting is done with `torch.cuda.memory_allocated()` and
  723. `torch.cuda.max_memory_allocated()`. This metric reports only "deltas" for pytorch-specific allocations, as
  724. `torch.cuda` memory management system doesn't track any memory allocated outside of pytorch. For example, the very
  725. first cuda call typically loads CUDA kernels, which may take from 0.5 to 2GB of GPU memory.
  726. Note that this tracker doesn't account for memory allocations outside of [`Trainer`]'s `__init__`, `train`,
  727. `evaluate` and `predict` calls.
  728. Because `evaluation` calls may happen during `train`, we can't handle nested invocations because
  729. `torch.cuda.max_memory_allocated` is a single counter, so if it gets reset by a nested eval call, `train`'s tracker
  730. will report incorrect info. If this [pytorch issue](https://github.com/pytorch/pytorch/issues/16266) gets resolved
  731. it will be possible to change this class to be re-entrant. Until then we will only track the outer level of
  732. `train`, `evaluate` and `predict` methods. Which means that if `eval` is called during `train`, it's the latter
  733. that will account for its memory usage and that of the former.
  734. This also means that if any other tool that is used along the [`Trainer`] calls
  735. `torch.cuda.reset_peak_memory_stats`, the gpu peak memory stats could be invalid. And the [`Trainer`] will disrupt
  736. the normal behavior of any such tools that rely on calling `torch.cuda.reset_peak_memory_stats` themselves.
  737. For best performance you may want to consider turning the memory profiling off for production runs.
  738. """
  739. if not self.is_world_process_zero():
  740. return
  741. print(f"***** {split} metrics *****")
  742. metrics_formatted = metrics_format(metrics)
  743. k_width = max(len(str(x)) for x in metrics_formatted)
  744. v_width = max(len(str(x)) for x in metrics_formatted.values())
  745. for key in sorted(metrics_formatted.keys()):
  746. print(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
  747. # Trainer helper method
  748. def save_metrics(self, split, metrics, combined=True):
  749. """
  750. Save metrics into a json file for that split, e.g. `train_results.json`.
  751. Under distributed environment this is done only for a process with rank 0.
  752. Args:
  753. split (`str`):
  754. Mode/split name: one of `train`, `eval`, `test`, `all`
  755. metrics (`dict[str, float]`):
  756. The metrics returned from train/evaluate/predict
  757. combined (`bool`, *optional*, defaults to `True`):
  758. Creates combined metrics by updating `all_results.json` with metrics of this call
  759. To understand the metrics please read the docstring of [`~Trainer.log_metrics`]. The only difference is that raw
  760. unformatted numbers are saved in the current method.
  761. """
  762. if not self.is_world_process_zero():
  763. return
  764. path = os.path.join(self.args.output_dir, f"{split}_results.json")
  765. with open(path, "w") as f:
  766. json.dump(metrics, f, indent=4, sort_keys=True)
  767. if combined:
  768. path = os.path.join(self.args.output_dir, "all_results.json")
  769. if os.path.exists(path):
  770. with open(path) as f:
  771. all_metrics = json.load(f)
  772. else:
  773. all_metrics = {}
  774. all_metrics.update(metrics)
  775. with open(path, "w") as f:
  776. json.dump(all_metrics, f, indent=4, sort_keys=True)
  777. # Trainer helper method
  778. def save_state(self):
  779. """
  780. Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model.
  781. Under distributed environment this is done only for a process with rank 0.
  782. """
  783. if not self.is_world_process_zero():
  784. return
  785. path = os.path.join(self.args.output_dir, "trainer_state.json")
  786. self.state.save_to_json(path)
  787. # Trainer helper method
  788. def get_num_trainable_parameters(self) -> int:
  789. """
  790. Get the number of trainable parameters.
  791. """
  792. return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
  793. # Trainer helper method
  794. def get_learning_rates(self) -> list[float]:
  795. """
  796. Returns the learning rate of each parameter from self.optimizer.
  797. """
  798. if self.optimizer is None:
  799. raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
  800. return [group["lr"] for group in self.optimizer.param_groups]
  801. # Trainer helper method
  802. def get_optimizer_group(self, param: str | torch.nn.parameter.Parameter | None = None):
  803. """
  804. Returns optimizer group for a parameter if given, else returns all optimizer groups for params.
  805. Args:
  806. param (`str` or `torch.nn.parameter.Parameter`, *optional*):
  807. The parameter for which optimizer group needs to be returned.
  808. """
  809. if self.optimizer is None:
  810. raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
  811. if param is not None:
  812. for group in self.optimizer.param_groups:
  813. if param in group["params"]:
  814. return group
  815. return [group["params"] for group in self.optimizer.param_groups]
  816. def get_model_param_count(model, trainable_only=False):
  817. """
  818. Calculate model's total param count. If trainable_only is True then count only those requiring grads.
  819. """
  820. if is_deepspeed_zero3_enabled():
  821. def numel(p):
  822. return p.ds_numel if hasattr(p, "ds_numel") else p.numel()
  823. else:
  824. def numel(p):
  825. return p.numel()
  826. return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
  827. def get_parameter_names(model, forbidden_layer_types, forbidden_layer_names=None):
  828. """
  829. Returns the names of the model parameters that are not inside a forbidden layer.
  830. """
  831. forbidden_layer_patterns = (
  832. [re.compile(pattern) for pattern in forbidden_layer_names] if forbidden_layer_names is not None else []
  833. )
  834. result = []
  835. for name, child in model.named_children():
  836. child_params = get_parameter_names(child, forbidden_layer_types, forbidden_layer_names)
  837. result += [
  838. f"{name}.{n}"
  839. for n in child_params
  840. if not isinstance(child, tuple(forbidden_layer_types))
  841. and not any(pattern.search(f"{name}.{n}".lower()) for pattern in forbidden_layer_patterns)
  842. ]
  843. # Add model specific parameters that are not in any child
  844. result += [
  845. k for k in model._parameters if not any(pattern.search(k.lower()) for pattern in forbidden_layer_patterns)
  846. ]
  847. return result
  848. def get_module_class_from_name(module, name):
  849. """
  850. Gets a class from a module by its name.
  851. Args:
  852. module (`torch.nn.Module`): The module to get the class from.
  853. name (`str`): The name of the class.
  854. """
  855. modules_children = list(module.children())
  856. if module.__class__.__name__ == name:
  857. return module.__class__
  858. elif len(modules_children) == 0:
  859. return
  860. else:
  861. for child_module in modules_children:
  862. module_class = get_module_class_from_name(child_module, name)
  863. if module_class is not None:
  864. return module_class
  865. def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
  866. if is_main_process:
  867. for filename in filenames:
  868. file = os.path.join(output_dir, filename)
  869. if os.path.isfile(file):
  870. os.remove(file)
# SageMaker model-parallel helpers; only defined when `is_sagemaker_mp_enabled()` is true.
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

    @smp.step()
    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
        """
        Forward `inputs` (as keyword arguments) through `model`, scale the loss and call `model.backward` on it.

        The loss is `outputs["loss"]` for dict outputs, otherwise `outputs[0]`. It is divided in place by
        `gradient_accumulation_steps` before the backward call, and the scaled loss is returned.
        NOTE(review): `smp.step()` presumably runs this per micro-batch and wraps the return value — confirm against
        the smdistributed docs.
        """
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        loss /= gradient_accumulation_steps
        model.backward(loss)
        return loss

    @smp.step()
    def smp_forward_only(model, inputs):
        """Forward `inputs` (as keyword arguments) through `model` under `smp.step()`, without a backward pass."""
        return model(**inputs)

    def smp_gather(tensor):
        """
        Gather `tensor` from every rank of the data-parallel group (`smp.CommGroup.DP_GROUP`) and concatenate on CPU.

        Lists/tuples/dicts are handled recursively, preserving their container type. Each gathered tensor is passed
        through `atleast_1d` (defined elsewhere; presumably promotes 0-d tensors so they can be concatenated) before
        the results are moved to CPU and joined along dim 0.

        Raises:
            TypeError: If a leaf is not a `torch.Tensor`.
        """
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(smp_gather(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )
        all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
        all_tensors = [atleast_1d(t) for t in all_tensors]
        return torch.cat([t.cpu() for t in all_tensors], dim=0)

    def smp_nested_concat(tensor):
        """
        Recursively apply `.detach().concat().cpu()` to every leaf of a nested list/tuple/dict, keeping container types.

        NOTE(review): leaves are assumed to be smp step outputs whose `.concat()` merges per-micro-batch results —
        see the comment below; not verifiable from this file.
        """
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(smp_nested_concat(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: smp_nested_concat(v) for k, v in tensor.items()})
        # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
        # which is also the name of the decorator so Python is confused.
        return tensor.detach().concat().cpu()
  903. @dataclass
  904. class AcceleratorConfig:
  905. """
  906. A subset of arguments relating to the underlying [`accelerate.Accelerator`]
  907. implementation utilized in the `Trainer` that can be customized.
  908. Mostly relating to data.
  909. Parameters:
  910. split_batches (`bool`, *optional*, defaults to `False`):
  911. Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
  912. `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
  913. round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
  914. in your script multiplied by the number of processes.
  915. dispatch_batches (`bool`, *optional*):
  916. If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
  917. and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
  918. underlying dataset is an `IterableDataset`, `False` otherwise.
  919. even_batches (`bool`, *optional*, defaults to `True`):
  920. If set to `True`, in cases where the total batch size across all processes does not exactly divide the
  921. dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
  922. all workers.
  923. use_seedable_sampler (`bool`, *optional*, defaults to `True`):
  924. Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
  925. training results are fully reproducible using a different sampling technique. While seed-to-seed results
  926. may differ, on average the differences are negligible when using multiple different seeds to compare. Should
  927. also be ran with [`~utils.set_seed`] for the best results.
  928. gradient_accumulation_kwargs (`dict`, *optional*):
  929. Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`].
  930. Any of the following (optional) keys are acceptable:
  931. num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if
  932. the latter is set to 1, otherwise an exception will be raised.
  933. sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch.
  934. The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`.
  935. non_blocking (`bool`, *optional*, defaults to `False`):
  936. Whether to use non-blocking CUDA calls to help minimize synchronization during
  937. distributed training with prepared `DataLoader` inputs being moved to device.
  938. Best if used with `pin_memory=True` in the `TrainingArguments`.
  939. use_configured_state (`bool*, *optional*, defaults to `False`):
  940. Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined
  941. before calling `TrainingArguments`. If `True`, an `Accelerator` or `PartialState`
  942. must be initialized. May lead to issues using sweeps or hyperparameter tuning.
  943. """
  944. # Data related arguments
  945. split_batches: bool = field(
  946. default=False,
  947. metadata={
  948. "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If"
  949. " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a"
  950. " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set"
  951. " in your script multiplied by the number of processes."
  952. },
  953. )
  954. dispatch_batches: bool | None = field(
  955. default=None,
  956. metadata={
  957. "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
  958. " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
  959. " underlying dataset is an `IterableDataslet`, `False` otherwise."
  960. },
  961. )
  962. even_batches: bool = field(
  963. default=True,
  964. metadata={
  965. "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the"
  966. " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among"
  967. " all workers."
  968. },
  969. )
  970. use_seedable_sampler: bool = field(
  971. default=True,
  972. metadata={
  973. "help": "Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`])."
  974. "Ensures training results are fully reproducible using a different sampling technique. "
  975. "While seed-to-seed results may differ, on average the differences are negligible when using"
  976. "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
  977. },
  978. )
  979. non_blocking: bool = field(
  980. default=False,
  981. metadata={
  982. "help": "Whether to use non-blocking CUDA calls to help minimize synchronization during "
  983. "distributed training with prepared `DataLoader` inputs being moved to device. "
  984. "Best if used with `pin_memory=True` in the `TrainingArguments`. Requires accelerate "
  985. "v0.30.0."
  986. },
  987. )
  988. gradient_accumulation_kwargs: dict | None = field(
  989. default=None,
  990. metadata={
  991. "help": "Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`]. "
  992. "Any of the following (optional) keys are acceptable: "
  993. " num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if "
  994. " the latter is set to 1, otherwise an exception will be raised. "
  995. " sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch. "
  996. " The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`."
  997. },
  998. )
  999. use_configured_state: bool = field(
  1000. default=False,
  1001. metadata={
  1002. "help": "Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`."
  1003. "If `True`, an `Accelerator` or `PartialState` must be initialized. May lead to issues using sweeps or hyperparameter tuning."
  1004. },
  1005. )
  1006. @classmethod
  1007. def from_json_file(cls, json_file):
  1008. # Check if exists
  1009. open_file = io.open if os.path.exists(json_file) else open
  1010. with open_file(json_file, "r", encoding="utf-8") as f:
  1011. config_dict = json.load(f)
  1012. # Check for keys and load sensible defaults
  1013. extra_keys = sorted(key for key in config_dict if key not in cls.__dataclass_fields__)
  1014. if len(extra_keys) > 0:
  1015. raise ValueError(
  1016. f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `transformers`"
  1017. " version or fix (and potentially remove these keys) from your config file."
  1018. )
  1019. return cls(**config_dict)
  1020. def to_dict(self):
  1021. return copy.deepcopy(self.__dict__)
  1022. def pop(self, key, default=None):
  1023. return self.__dict__.pop(key, default)
  1024. class LayerWiseDummyOptimizer(torch.optim.Optimizer):
  1025. """
  1026. For Layer-wise optimizers such as GaLoRE optimizer, the optimization
  1027. step is already done through the post gradient hooks. Therefore
  1028. the trick is to create a dummy optimizer that can take arbitrary
  1029. args and kwargs and return a no-op during training.
  1030. Initial idea from @hiyouga in LLaMA-Factory:
  1031. https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
  1032. """
  1033. def __init__(self, optimizer_dict=None, **kwargs):
  1034. dummy_tensor = torch.randn(1, 1)
  1035. self.optimizer_dict = optimizer_dict
  1036. super().__init__([dummy_tensor], {"lr": kwargs.get("lr", 1e-03)})
  1037. def zero_grad(self, set_to_none: bool = True) -> None:
  1038. pass
  1039. def step(self, closure=None) -> float | None:
  1040. pass
  1041. class LayerWiseDummyScheduler(LRScheduler):
  1042. """
  1043. For Layer-wise optimizers such as GaLoRE optimizer, the optimization and scheduling step
  1044. are already done through the post gradient hooks. Therefore
  1045. the trick is to create a dummy scheduler that can take arbitrary
  1046. args and kwargs and return a no-op during training.
  1047. """
  1048. def __init__(self, *args, **kwargs):
  1049. self.default_lr = kwargs["lr"]
  1050. optimizer = LayerWiseDummyOptimizer(**kwargs)
  1051. last_epoch = -1
  1052. super().__init__(optimizer, last_epoch)
  1053. def get_lr(self):
  1054. # default value
  1055. lrs = [self.default_lr]
  1056. # we take each lr in the parameters if they exist, assumes the optimizer to be the `LayerWiseDummyOptimizer`
  1057. if self.optimizer is not None:
  1058. param_wise_lrs = [
  1059. [group["lr"] for group in optim.param_groups] for optim in self.optimizer.optimizer_dict.values()
  1060. ]
  1061. lrs = list(chain(*param_wise_lrs))
  1062. return lrs
  1063. def _get_closed_form_lr(self):
  1064. return self.base_lrs
  1065. def set_rng_state_for_device(device_name, device_module, checkpoint_rng_state, is_distributed):
  1066. """Helper to set RNG state for a specific device type (CUDA, NPU, MLU, MUSA)"""
  1067. device_state_key = device_name.lower()
  1068. err_template = "Didn't manage to set back the RNG states of the {backend} because of the following error:\n {exception}\nThis won't yield the same results as if the training had not been interrupted."
  1069. try:
  1070. if is_distributed:
  1071. device_module.random.set_rng_state_all(checkpoint_rng_state[device_state_key])
  1072. else:
  1073. device_module.random.set_rng_state(checkpoint_rng_state[device_state_key])
  1074. except Exception as e:
  1075. # Log error if setting RNG state fails
  1076. logger.error(err_template.format(backend=device_name, exception=e))
  1077. def safe_globals():
  1078. """
  1079. Context manager to allowlist numpy objects for torch.load with weights_only=True.
  1080. Starting from version 2.4 PyTorch introduces a check for the objects loaded
  1081. with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes
  1082. a default and requires allowlisting of objects being loaded.
  1083. See: https://github.com/pytorch/pytorch/pull/137602
  1084. See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
  1085. See: https://github.com/huggingface/accelerate/pull/3036
  1086. """
  1087. if version.parse(torch.__version__).release < version.parse("2.6").release:
  1088. return contextlib.nullcontext()
  1089. np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core
  1090. allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]
  1091. # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
  1092. # all versions of numpy
  1093. allowlist += [type(np.dtype(np.uint32))]
  1094. return torch.serialization.safe_globals(allowlist)