# trainer_callback.py (33 KB)
  1. # Copyright 2020-present the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Callbacks to use with the Trainer class and customize the training loop.
  16. """
  17. import dataclasses
  18. import json
  19. import math
  20. from dataclasses import dataclass
  21. import numpy as np
  22. from tqdm.auto import tqdm
  23. from .trainer_utils import IntervalStrategy, SaveStrategy, has_length
  24. from .training_args import TrainingArguments
  25. from .utils import logging
# Module-level logger for this file, obtained from the package's `.utils.logging` helpers.
logger = logging.get_logger(__name__)
  27. @dataclass
  28. class TrainerState:
  29. """
  30. A class containing the [`Trainer`] inner state that will be saved along the model and optimizer when checkpointing
  31. and passed to the [`TrainerCallback`].
  32. <Tip>
  33. In all this class, one step is to be understood as one update step. When using gradient accumulation, one update
  34. step may require several forward and backward passes: if you use `gradient_accumulation_steps=n`, then one update
  35. step requires going through *n* batches.
  36. </Tip>
  37. Args:
  38. epoch (`float`, *optional*):
  39. Only set during training, will represent the epoch the training is at (the decimal part being the
  40. percentage of the current epoch completed).
  41. global_step (`int`, *optional*, defaults to 0):
  42. During training, represents the number of update steps completed.
  43. max_steps (`int`, *optional*, defaults to 0):
  44. The number of update steps to do during the current training.
  45. logging_steps (`int`, *optional*, defaults to 500):
  46. Log every X updates steps
  47. eval_steps (`int`, *optional*):
  48. Run an evaluation every X steps.
  49. save_steps (`int`, *optional*, defaults to 500):
  50. Save checkpoint every X updates steps.
  51. train_batch_size (`int`, *optional*):
  52. The batch size for the training dataloader. Only needed when
  53. `auto_find_batch_size` has been used.
  54. num_input_tokens_seen (`int`, *optional*, defaults to 0):
  55. When tracking the inputs tokens, the number of tokens seen during training (number of input tokens, not the
  56. number of prediction tokens).
  57. total_flos (`float`, *optional*, defaults to 0):
  58. The total number of floating operations done by the model since the beginning of training (stored as floats
  59. to avoid overflow).
  60. log_history (`list[dict[str, float]]`, *optional*):
  61. The list of logs done since the beginning of training.
  62. best_metric (`float`, *optional*):
  63. When tracking the best model, the value of the best metric encountered so far.
  64. best_global_step (`int`, *optional*):
  65. When tracking the best model, the step at which the best metric was encountered.
  66. Used for setting `best_model_checkpoint`.
  67. best_model_checkpoint (`str`, *optional*):
  68. When tracking the best model, the value of the name of the checkpoint for the best model encountered so
  69. far.
  70. is_local_process_zero (`bool`, *optional*, defaults to `True`):
  71. Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
  72. several machines) main process.
  73. is_world_process_zero (`bool`, *optional*, defaults to `True`):
  74. Whether or not this process is the global main process (when training in a distributed fashion on several
  75. machines, this is only going to be `True` for one process).
  76. is_hyper_param_search (`bool`, *optional*, defaults to `False`):
  77. Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
  78. impact the way data will be logged in TensorBoard.
  79. stateful_callbacks (`list[StatefulTrainerCallback]`, *optional*):
  80. Callbacks attached to the `Trainer` that should have their states be saved or restored.
  81. Relevant callbacks should implement a `state` and `from_state` function.
  82. """
  83. epoch: float = 0
  84. global_step: int = 0
  85. max_steps: int = 0
  86. logging_steps: int = 500
  87. eval_steps: int = 500
  88. save_steps: int = 500
  89. train_batch_size: int | None = None
  90. num_train_epochs: int = 0
  91. num_input_tokens_seen: int = 0
  92. total_flos: float = 0
  93. log_history: list[dict[str, float]] = None
  94. best_metric: float | None = None
  95. best_global_step: int | None = None
  96. best_model_checkpoint: str | None = None
  97. is_local_process_zero: bool = True
  98. is_world_process_zero: bool = True
  99. is_hyper_param_search: bool = False
  100. trial_name: str | None = None
  101. trial_params: dict[str, str | float | int | bool] | None = None
  102. stateful_callbacks: list["TrainerCallback"] | None = None
  103. def __post_init__(self):
  104. if self.log_history is None:
  105. self.log_history = []
  106. if self.stateful_callbacks is None:
  107. self.stateful_callbacks = {}
  108. elif isinstance(self.stateful_callbacks, dict):
  109. # We are loading the callbacks in from the state file, no need to process them
  110. pass
  111. else:
  112. # Saveable callbacks get stored as dict of kwargs
  113. stateful_callbacks = {}
  114. for callback in self.stateful_callbacks:
  115. if not isinstance(callback, (ExportableState)):
  116. raise TypeError(
  117. f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
  118. )
  119. name = callback.__class__.__name__
  120. if name in stateful_callbacks:
  121. # We can have multiple versions of the same callback
  122. # if so, we store them as a list of states to restore
  123. if not isinstance(stateful_callbacks[name], list):
  124. stateful_callbacks[name] = [stateful_callbacks[name]]
  125. stateful_callbacks[name].append(callback.state())
  126. else:
  127. stateful_callbacks[name] = callback.state()
  128. self.stateful_callbacks = stateful_callbacks
  129. def save_to_json(self, json_path: str):
  130. """Save the content of this instance in JSON format inside `json_path`."""
  131. json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
  132. with open(json_path, "w", encoding="utf-8") as f:
  133. f.write(json_string)
  134. @classmethod
  135. def load_from_json(cls, json_path: str):
  136. """Create an instance from the content of `json_path`."""
  137. with open(json_path, encoding="utf-8") as f:
  138. text = f.read()
  139. return cls(**json.loads(text))
  140. def compute_steps(self, args, max_steps):
  141. """
  142. Calculates and stores the absolute value for logging,
  143. eval, and save steps based on if it was a proportion
  144. or not.
  145. """
  146. for step_kind in ("logging", "eval", "save"):
  147. num_steps = getattr(args, f"{step_kind}_steps")
  148. if num_steps is not None:
  149. if num_steps < 1:
  150. num_steps = math.ceil(max_steps * num_steps)
  151. setattr(self, f"{step_kind}_steps", num_steps)
  152. def init_training_references(self, trainer, max_steps, num_train_epochs, trial):
  153. """
  154. Stores the initial training references needed in `self`
  155. """
  156. if trainer.hp_name is not None and trainer._trial is not None:
  157. # use self._trial because the Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
  158. # parameter to Train when using DDP.
  159. self.trial_name = trainer.hp_name(trainer._trial)
  160. self.trial_params = None
  161. if trial is not None:
  162. from transformers.integrations import hp_params
  163. self.trial_params = hp_params(trial)
  164. self.max_steps = max_steps
  165. self.num_train_epochs = num_train_epochs
  166. self.is_local_process_zero = trainer.is_local_process_zero()
  167. self.is_world_process_zero = trainer.is_world_process_zero()
  168. class ExportableState:
  169. """
  170. A class for objects that include the ability to have its state
  171. be saved during `Trainer._save_checkpoint` and loaded back in during
  172. `Trainer._load_from_checkpoint`.
  173. These must implement a `state` function that gets called during the respective
  174. Trainer function call. It should only include parameters and attributes needed to
  175. recreate the state at a particular time, to avoid utilizing pickle/maintain standard
  176. file IO writing.
  177. Example:
  178. ```python
  179. class EarlyStoppingCallback(TrainerCallback, ExportableState):
  180. def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
  181. self.early_stopping_patience = early_stopping_patience
  182. self.early_stopping_threshold = early_stopping_threshold
  183. # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
  184. self.early_stopping_patience_counter = 0
  185. def state(self) -> dict:
  186. return {
  187. "args": {
  188. "early_stopping_patience": self.early_stopping_patience,
  189. "early_stopping_threshold": self.early_stopping_threshold,
  190. },
  191. "attributes": {
  192. "early_stopping_patience_counter": self.early_stopping_patience_counter,
  193. }
  194. }
  195. ```"""
  196. def state(self) -> dict:
  197. raise NotImplementedError("You must implement a `state` function to utilize this class.")
  198. @classmethod
  199. def from_state(cls, state):
  200. instance = cls(**state["args"])
  201. for k, v in state["attributes"].items():
  202. setattr(instance, k, v)
  203. return instance
  204. @dataclass
  205. class TrainerControl(ExportableState):
  206. """
  207. A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
  208. switches in the training loop.
  209. Args:
  210. should_training_stop (`bool`, *optional*, defaults to `False`):
  211. Whether or not the training should be interrupted.
  212. If `True`, this variable will not be set back to `False`. The training will just stop.
  213. should_epoch_stop (`bool`, *optional*, defaults to `False`):
  214. Whether or not the current epoch should be interrupted.
  215. If `True`, this variable will be set back to `False` at the beginning of the next epoch.
  216. should_save (`bool`, *optional*, defaults to `False`):
  217. Whether or not the model should be saved at this step.
  218. If `True`, this variable will be set back to `False` at the beginning of the next step.
  219. should_evaluate (`bool`, *optional*, defaults to `False`):
  220. Whether or not the model should be evaluated at this step.
  221. If `True`, this variable will be set back to `False` at the beginning of the next step.
  222. should_log (`bool`, *optional*, defaults to `False`):
  223. Whether or not the logs should be reported at this step.
  224. If `True`, this variable will be set back to `False` at the beginning of the next step.
  225. """
  226. should_training_stop: bool = False
  227. should_epoch_stop: bool = False
  228. should_save: bool = False
  229. should_evaluate: bool = False
  230. should_log: bool = False
  231. def _new_training(self):
  232. """Internal method that resets the variable for a new training."""
  233. self.should_training_stop = False
  234. def _new_epoch(self):
  235. """Internal method that resets the variable for a new epoch."""
  236. self.should_epoch_stop = False
  237. def _new_step(self):
  238. """Internal method that resets the variable for a new step."""
  239. self.should_save = False
  240. self.should_evaluate = False
  241. self.should_log = False
  242. def state(self) -> dict:
  243. return {
  244. "args": {
  245. "should_training_stop": self.should_training_stop,
  246. "should_epoch_stop": self.should_epoch_stop,
  247. "should_save": self.should_save,
  248. "should_evaluate": self.should_evaluate,
  249. "should_log": self.should_log,
  250. },
  251. "attributes": {},
  252. }
  253. class TrainerCallback:
  254. # no-format
  255. """
  256. A class for objects that will inspect the state of the training loop at some events and take some decisions. At
  257. each of those events the following arguments are available:
  258. Args:
  259. args ([`TrainingArguments`]):
  260. The training arguments used to instantiate the [`Trainer`].
  261. state ([`TrainerState`]):
  262. The current state of the [`Trainer`].
  263. control ([`TrainerControl`]):
  264. The object that is returned to the [`Trainer`] and can be used to make some decisions.
  265. model ([`PreTrainedModel`] or `torch.nn.Module`):
  266. The model being trained.
  267. processing_class ([`PreTrainedTokenizer` or `BaseImageProcessor` or `ProcessorMixin` or `FeatureExtractionMixin`]):
  268. The processing class used for encoding the data. Can be a tokenizer, a processor, an image processor or a feature extractor.
  269. optimizer (`torch.optim.Optimizer`):
  270. The optimizer used for the training steps.
  271. lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`):
  272. The scheduler used for setting the learning rate.
  273. train_dataloader (`torch.utils.data.DataLoader`, *optional*):
  274. The current dataloader used for training.
  275. eval_dataloader (`torch.utils.data.DataLoader`, *optional*):
  276. The current dataloader used for evaluation.
  277. metrics (`dict[str, float]`):
  278. The metrics computed by the last evaluation phase.
  279. Those are only accessible in the event `on_evaluate`.
  280. logs (`dict[str, float]`):
  281. The values to log.
  282. Those are only accessible in the event `on_log`.
  283. The `control` object is the only one that can be changed by the callback, in which case the event that changes it
  284. should return the modified version.
  285. The argument `args`, `state` and `control` are positionals for all events, all the others are grouped in `kwargs`.
  286. You can unpack the ones you need in the signature of the event using them. As an example, see the code of the
  287. simple [`~transformers.PrinterCallback`].
  288. Example:
  289. ```python
  290. class PrinterCallback(TrainerCallback):
  291. def on_log(self, args, state, control, logs=None, **kwargs):
  292. _ = logs.pop("total_flos", None)
  293. if state.is_local_process_zero:
  294. print(logs)
  295. ```"""
  296. def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  297. """
  298. Event called at the end of the initialization of the [`Trainer`].
  299. """
  300. def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  301. """
  302. Event called at the beginning of training.
  303. """
  304. def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  305. """
  306. Event called at the end of training.
  307. """
  308. def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  309. """
  310. Event called at the beginning of an epoch.
  311. """
  312. def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  313. """
  314. Event called at the end of an epoch.
  315. """
  316. def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  317. """
  318. Event called at the beginning of a training step. If using gradient accumulation, one training step might take
  319. several inputs.
  320. """
  321. def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  322. """
  323. Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients.
  324. """
  325. def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  326. """
  327. Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
  328. """
  329. def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  330. """
  331. Event called at the end of an substep during gradient accumulation.
  332. """
  333. def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  334. """
  335. Event called at the end of a training step. If using gradient accumulation, one training step might take
  336. several inputs.
  337. """
  338. def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  339. """
  340. Event called after an evaluation phase.
  341. """
  342. def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs):
  343. """
  344. Event called after a successful prediction.
  345. """
  346. def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  347. """
  348. Event called after a checkpoint save.
  349. """
  350. def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  351. """
  352. Event called after logging the last logs.
  353. """
  354. def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  355. """
  356. Event called after a prediction step.
  357. """
  358. def on_push_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  359. """
  360. Event called before pushing the model to the hub, at the beginning of Trainer.push_to_hub and Trainer._push_from_checkpoint.
  361. """
  362. class CallbackHandler(TrainerCallback):
  363. """Internal class that just calls the list of callbacks in order."""
  364. def __init__(self, callbacks, model, processing_class, optimizer, lr_scheduler):
  365. self.callbacks = []
  366. for cb in callbacks:
  367. self.add_callback(cb)
  368. self.model = model
  369. self.processing_class = processing_class
  370. self.optimizer = optimizer
  371. self.lr_scheduler = lr_scheduler
  372. self.train_dataloader = None
  373. self.eval_dataloader = None
  374. if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
  375. logger.warning(
  376. "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
  377. + "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of"
  378. + "callbacks is\n:"
  379. + self.callback_list
  380. )
  381. def add_callback(self, callback):
  382. cb = callback() if isinstance(callback, type) else callback
  383. cb_class = callback if isinstance(callback, type) else callback.__class__
  384. if cb_class in [c.__class__ for c in self.callbacks]:
  385. logger.warning(
  386. f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current"
  387. + "list of callbacks is\n:"
  388. + self.callback_list
  389. )
  390. self.callbacks.append(cb)
  391. def pop_callback(self, callback):
  392. if isinstance(callback, type):
  393. for cb in self.callbacks:
  394. if isinstance(cb, callback):
  395. self.callbacks.remove(cb)
  396. return cb
  397. else:
  398. for cb in self.callbacks:
  399. if cb == callback:
  400. self.callbacks.remove(cb)
  401. return cb
  402. def remove_callback(self, callback):
  403. if isinstance(callback, type):
  404. for cb in self.callbacks:
  405. if isinstance(cb, callback):
  406. self.callbacks.remove(cb)
  407. return
  408. else:
  409. self.callbacks.remove(callback)
  410. @property
  411. def callback_list(self):
  412. return "\n".join(cb.__class__.__name__ for cb in self.callbacks)
  413. def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
  414. return self.call_event("on_init_end", args, state, control)
  415. def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
  416. control.should_training_stop = False
  417. return self.call_event("on_train_begin", args, state, control)
  418. def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
  419. return self.call_event("on_train_end", args, state, control)
  420. def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
  421. control.should_epoch_stop = False
  422. return self.call_event("on_epoch_begin", args, state, control)
  423. def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
  424. return self.call_event("on_epoch_end", args, state, control)
  425. def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
  426. control.should_log = False
  427. control.should_evaluate = False
  428. control.should_save = False
  429. return self.call_event("on_step_begin", args, state, control)
  430. def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
  431. return self.call_event("on_pre_optimizer_step", args, state, control)
  432. def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
  433. return self.call_event("on_optimizer_step", args, state, control)
  434. def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
  435. return self.call_event("on_substep_end", args, state, control)
  436. def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
  437. return self.call_event("on_step_end", args, state, control)
  438. def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
  439. control.should_evaluate = False
  440. return self.call_event("on_evaluate", args, state, control, metrics=metrics)
  441. def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
  442. return self.call_event("on_predict", args, state, control, metrics=metrics)
  443. def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
  444. control.should_save = False
  445. return self.call_event("on_save", args, state, control)
  446. def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs):
  447. control.should_log = False
  448. return self.call_event("on_log", args, state, control, logs=logs)
  449. def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
  450. return self.call_event("on_prediction_step", args, state, control)
  451. def on_push_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  452. return self.call_event("on_push_begin", args, state, control, **kwargs)
  453. def call_event(self, event, args, state, control, **kwargs):
  454. for callback in self.callbacks:
  455. result = getattr(callback, event)(
  456. args,
  457. state,
  458. control,
  459. model=self.model,
  460. processing_class=self.processing_class,
  461. optimizer=self.optimizer,
  462. lr_scheduler=self.lr_scheduler,
  463. train_dataloader=self.train_dataloader,
  464. eval_dataloader=self.eval_dataloader,
  465. **kwargs,
  466. )
  467. # A Callback can skip the return of `control` if it doesn't change it.
  468. if result is not None:
  469. control = result
  470. return control
  471. class DefaultFlowCallback(TrainerCallback):
  472. """
  473. A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
  474. """
  475. def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  476. # Log
  477. if state.global_step == 1 and args.logging_first_step:
  478. control.should_log = True
  479. if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % state.logging_steps == 0:
  480. control.should_log = True
  481. # Evaluate
  482. if (
  483. args.eval_strategy == IntervalStrategy.STEPS
  484. and state.global_step % state.eval_steps == 0
  485. and args.eval_delay <= state.global_step
  486. ):
  487. control.should_evaluate = True
  488. # Save
  489. if (
  490. args.save_strategy == SaveStrategy.STEPS
  491. and state.save_steps > 0
  492. and state.global_step % state.save_steps == 0
  493. ):
  494. control.should_save = True
  495. # End training
  496. if state.global_step >= state.max_steps:
  497. control.should_training_stop = True
  498. # Evaluate at the end if we have a step-based eval strategy and this step
  499. # wasn't already going to be evaluated (to avoid duplicate evaluation).
  500. if (
  501. args.eval_strategy == IntervalStrategy.STEPS
  502. and state.global_step % state.eval_steps != 0
  503. and args.eval_delay <= state.global_step
  504. ):
  505. control.should_evaluate = True
  506. # Save the model at the end if we have a save strategy
  507. if args.save_strategy == SaveStrategy.STEPS:
  508. control.should_save = True
  509. return control
  510. def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
  511. # Log
  512. if args.logging_strategy == IntervalStrategy.EPOCH:
  513. control.should_log = True
  514. # Evaluate
  515. if args.eval_strategy == IntervalStrategy.EPOCH and args.eval_delay <= state.epoch:
  516. control.should_evaluate = True
  517. # Save
  518. if args.save_strategy == SaveStrategy.EPOCH:
  519. control.should_save = True
  520. return control
  521. class ProgressCallback(TrainerCallback):
  522. """
  523. A [`TrainerCallback`] that displays the progress of training or evaluation.
  524. You can modify `max_str_len` to control how long strings are truncated when logging.
  525. """
  526. def __init__(self, max_str_len: int = 100):
  527. """
  528. Initialize the callback with optional max_str_len parameter to control string truncation length.
  529. Args:
  530. max_str_len (`int`):
  531. Maximum length of strings to display in logs.
  532. Longer strings will be truncated with a message.
  533. """
  534. self.training_bar = None
  535. self.prediction_bar = None
  536. self.max_str_len = max_str_len
  537. def on_train_begin(self, args, state, control, **kwargs):
  538. if state.is_world_process_zero:
  539. self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True)
  540. self.current_step = 0
  541. def on_step_end(self, args, state, control, **kwargs):
  542. if state.is_world_process_zero:
  543. self.training_bar.update(state.global_step - self.current_step)
  544. self.current_step = state.global_step
  545. def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
  546. if state.is_world_process_zero and has_length(eval_dataloader):
  547. if self.prediction_bar is None:
  548. self.prediction_bar = tqdm(
  549. total=len(eval_dataloader), leave=self.training_bar is None, dynamic_ncols=True
  550. )
  551. self.prediction_bar.update(1)
  552. def on_evaluate(self, args, state, control, **kwargs):
  553. if state.is_world_process_zero:
  554. if self.prediction_bar is not None:
  555. self.prediction_bar.close()
  556. self.prediction_bar = None
  557. def on_predict(self, args, state, control, **kwargs):
  558. if state.is_world_process_zero:
  559. if self.prediction_bar is not None:
  560. self.prediction_bar.close()
  561. self.prediction_bar = None
  562. def on_log(self, args, state, control, logs=None, **kwargs):
  563. if state.is_world_process_zero and self.training_bar is not None:
  564. # make a shallow copy of logs so we can mutate the fields copied
  565. # but avoid doing any value pickling.
  566. shallow_logs = {}
  567. for k, v in logs.items():
  568. if isinstance(v, str) and len(v) > self.max_str_len:
  569. shallow_logs[k] = (
  570. f"[String too long to display, length: {len(v)} > {self.max_str_len}. "
  571. "Consider increasing `max_str_len` if needed.]"
  572. )
  573. if isinstance(v, float):
  574. # Format floats for better readability
  575. shallow_logs[k] = f"{v:.4g}"
  576. else:
  577. shallow_logs[k] = v
  578. _ = shallow_logs.pop("total_flos", None)
  579. self.training_bar.write(str(shallow_logs))
  580. def on_train_end(self, args, state, control, **kwargs):
  581. if state.is_world_process_zero:
  582. self.training_bar.close()
  583. self.training_bar = None
  584. class PrinterCallback(TrainerCallback):
  585. """
  586. A bare [`TrainerCallback`] that just prints the logs.
  587. """
  588. def on_log(self, args, state, control, logs=None, **kwargs):
  589. _ = logs.pop("total_flos", None)
  590. if state.is_local_process_zero:
  591. if logs is not None:
  592. logs = {k: (f"{v:.4g}" if isinstance(v, float) else v) for k, v in logs.items()}
  593. print(logs)
  594. class EarlyStoppingCallback(TrainerCallback, ExportableState):
  595. """
  596. A [`TrainerCallback`] that handles early stopping.
  597. Args:
  598. early_stopping_patience (`int`):
  599. Use with `metric_for_best_model` to stop training when the specified metric worsens for
  600. `early_stopping_patience` evaluation calls.
  601. early_stopping_threshold(`float`, *optional*):
  602. Use with TrainingArguments `metric_for_best_model` and `early_stopping_patience` to denote how much the
  603. specified metric must improve to satisfy early stopping conditions. `
  604. This callback depends on [`TrainingArguments`] argument *load_best_model_at_end* functionality to set best_metric
  605. in [`TrainerState`]. Note that if the [`TrainingArguments`] argument *save_steps* differs from *eval_steps*, the
  606. early stopping will not occur until the next save step.
  607. """
  608. def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: float | None = 0.0):
  609. self.early_stopping_patience = early_stopping_patience
  610. self.early_stopping_threshold = early_stopping_threshold
  611. # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
  612. self.early_stopping_patience_counter = 0
  613. def check_metric_value(self, args, state, control, metric_value):
  614. # best_metric is set by code for load_best_model
  615. operator = np.greater if args.greater_is_better else np.less
  616. if state.best_metric is None or (
  617. operator(metric_value, state.best_metric)
  618. and abs(metric_value - state.best_metric) > self.early_stopping_threshold
  619. ):
  620. self.early_stopping_patience_counter = 0
  621. else:
  622. self.early_stopping_patience_counter += 1
  623. def on_train_begin(self, args, state, control, **kwargs):
  624. if not args.load_best_model_at_end:
  625. logger.warning(
  626. "Using EarlyStoppingCallback without load_best_model_at_end=True. "
  627. "Once training is finished, the best model will not be loaded automatically."
  628. )
  629. assert args.metric_for_best_model is not None, (
  630. "EarlyStoppingCallback requires metric_for_best_model to be defined"
  631. )
  632. assert args.eval_strategy != IntervalStrategy.NO, (
  633. "EarlyStoppingCallback requires IntervalStrategy of steps or epoch"
  634. )
  635. def on_evaluate(self, args, state, control, metrics, **kwargs):
  636. metric_to_check = args.metric_for_best_model
  637. if not metric_to_check.startswith("eval_"):
  638. metric_to_check = f"eval_{metric_to_check}"
  639. metric_value = metrics.get(metric_to_check)
  640. if metric_value is None:
  641. logger.warning(
  642. f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping"
  643. " is disabled"
  644. )
  645. return
  646. self.check_metric_value(args, state, control, metric_value)
  647. if self.early_stopping_patience_counter >= self.early_stopping_patience:
  648. control.should_training_stop = True
  649. def state(self) -> dict:
  650. return {
  651. "args": {
  652. "early_stopping_patience": self.early_stopping_patience,
  653. "early_stopping_threshold": self.early_stopping_threshold,
  654. },
  655. "attributes": {
  656. "early_stopping_patience_counter": self.early_stopping_patience_counter,
  657. },
  658. }