integration_utils.py 112 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Integrations with other Python libraries.
  16. """
  17. import copy
  18. import functools
  19. import importlib.metadata
  20. import importlib.util
  21. import json
  22. import numbers
  23. import os
  24. import re
  25. import shutil
  26. import sys
  27. import tempfile
  28. import warnings
  29. from dataclasses import fields
  30. from enum import Enum
  31. from pathlib import Path
  32. from typing import TYPE_CHECKING, Any, Literal
  33. import numpy as np
  34. import packaging.version
# Import-time notice: printed to stdout when the `WANDB_MODE` environment variable is exactly "offline".
# NOTE(review): this runs on every import of the module and uses `print` because the module-level
# `logger` is only created after the package imports below — consider moving it under the logger setup.
if os.getenv("WANDB_MODE") == "offline":
    print("[INFO] Running in WANDB offline mode")
  37. from .. import PreTrainedModel, TrainingArguments
  38. from .. import __version__ as version
  39. from ..utils import (
  40. PushToHubMixin,
  41. flatten_dict,
  42. is_datasets_available,
  43. is_pandas_available,
  44. is_torch_available,
  45. logging,
  46. )
# Module-level logger used by the helpers defined in this file.
logger = logging.get_logger(__name__)

# torch is optional: it is only imported when `is_torch_available()` reports it.
if is_torch_available():
    import torch

# comet_ml requires to be imported before any ML frameworks
# Minimum comet_ml version; compared against the installed version right below.
_MIN_COMET_VERSION = "3.43.2"
try:
    # Raises PackageNotFoundError when the `comet_ml` distribution is not installed.
    _comet_version = importlib.metadata.version("comet_ml")
    _is_comet_installed = True

    _is_comet_recent_enough = packaging.version.parse(_comet_version) >= packaging.version.parse(_MIN_COMET_VERSION)

    # Check if the Comet API Key is set
    import comet_ml

    if comet_ml.config.get_config("comet.api_key") is not None:
        _is_comet_configured = True
    else:
        _is_comet_configured = False
except (importlib.metadata.PackageNotFoundError, ImportError, ValueError, TypeError, AttributeError, KeyError):
    # Any of the listed failures while probing comet_ml resets every flag, so Comet is reported as
    # not installed regardless of which step failed.
    _comet_version = None
    _is_comet_installed = False
    _is_comet_recent_enough = False
    _is_comet_configured = False
# Neptune is considered present when a module spec can be found under either name.
_has_neptune = (
    importlib.util.find_spec("neptune") is not None or importlib.util.find_spec("neptune-client") is not None
)
# NOTE(review): `typing.TYPE_CHECKING` is False at runtime, so the version probe below only runs under a
# static type checker and `_has_neptune` keeps the `find_spec` result at runtime — confirm this is intended.
if TYPE_CHECKING and _has_neptune:
    try:
        _neptune_version = importlib.metadata.version("neptune")
        logger.info(f"Neptune version {_neptune_version} available.")
    except importlib.metadata.PackageNotFoundError:
        # No `neptune` distribution metadata: fall back to the `neptune-client` distribution name.
        try:
            _neptune_version = importlib.metadata.version("neptune-client")
            logger.info(f"Neptune-client version {_neptune_version} available.")
        except importlib.metadata.PackageNotFoundError:
            # Neither distribution has metadata: treat Neptune as unavailable.
            _has_neptune = False
  80. from .. import modelcard # noqa: E402
  81. from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
  82. from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
  83. from ..training_args import ParallelMode # noqa: E402
  84. from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available # noqa: E402
  85. # Integration functions:
  86. def is_wandb_available():
  87. if importlib.util.find_spec("wandb") is not None:
  88. import wandb
  89. # wandb might still be detected by find_spec after an uninstall (leftover files or metadata), but not actually
  90. # import correctly. To confirm it's fully installed and usable, we check for a key attribute like "run".
  91. return hasattr(wandb, "run")
  92. else:
  93. return False
  94. def is_trackio_available():
  95. return importlib.util.find_spec("trackio") is not None
  96. def is_clearml_available():
  97. return importlib.util.find_spec("clearml") is not None
  98. def is_comet_available():
  99. if _is_comet_installed is False:
  100. return False
  101. if _is_comet_recent_enough is False:
  102. logger.warning(
  103. "comet_ml version %s is installed, but version %s or higher is required. "
  104. "Please update comet_ml to the latest version to enable Comet logging with pip install 'comet-ml>=%s'.",
  105. _comet_version,
  106. _MIN_COMET_VERSION,
  107. _MIN_COMET_VERSION,
  108. )
  109. return False
  110. if _is_comet_configured is False:
  111. logger.warning(
  112. "comet_ml is installed but the Comet API Key is not configured. "
  113. "Please set the `COMET_API_KEY` environment variable to enable Comet logging. "
  114. "Check out the documentation for other ways of configuring it: "
  115. "https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#set-the-api-key"
  116. )
  117. return False
  118. return True
  119. def is_tensorboard_available():
  120. return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None
  121. def is_optuna_available():
  122. return importlib.util.find_spec("optuna") is not None
  123. def is_ray_available():
  124. return importlib.util.find_spec("ray") is not None
  125. def is_ray_tune_available():
  126. if not is_ray_available():
  127. return False
  128. return importlib.util.find_spec("ray.tune") is not None
  129. def is_azureml_available():
  130. if importlib.util.find_spec("azureml") is None:
  131. return False
  132. if importlib.util.find_spec("azureml.core") is None:
  133. return False
  134. return importlib.util.find_spec("azureml.core.run") is not None
  135. def is_mlflow_available():
  136. if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE":
  137. return False
  138. return importlib.util.find_spec("mlflow") is not None
  139. def is_dagshub_available():
  140. return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]
  141. def is_neptune_available():
  142. return _has_neptune
  143. def is_codecarbon_available():
  144. return importlib.util.find_spec("codecarbon") is not None
  145. def is_flytekit_available():
  146. return importlib.util.find_spec("flytekit") is not None
  147. def is_flyte_deck_standard_available():
  148. if not is_flytekit_available():
  149. return False
  150. return importlib.util.find_spec("flytekitplugins.deck") is not None
  151. def is_dvclive_available():
  152. return importlib.util.find_spec("dvclive") is not None
  153. def is_swanlab_available():
  154. return importlib.util.find_spec("swanlab") is not None
  155. def is_kubeflow_available():
  156. if os.getenv("DISABLE_KUBEFLOW_INTEGRATION", "FALSE").upper() == "TRUE":
  157. return False
  158. return os.getenv("KUBEFLOW_TRAINER_SERVER_URL") is not None
  159. def hp_params(trial):
  160. if is_optuna_available():
  161. import optuna
  162. if isinstance(trial, optuna.trial.BaseTrial):
  163. return trial.params
  164. if is_ray_tune_available():
  165. if isinstance(trial, dict):
  166. return trial
  167. if is_wandb_available():
  168. if isinstance(trial, dict):
  169. return trial
  170. raise RuntimeError(f"Unknown type for trial {trial.__class__}")
def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search with Optuna.

    The process with `trainer.args.process_index == 0` creates the study and runs the trials. Every other
    process loops `n_trials` times, receives each trial through `torch.distributed.broadcast_object_list`
    and trains with it.

    Args:
        trainer: The trainer to run; its `hp_space`, `train`, `evaluate` and `compute_objective` are used.
        n_trials (`int`): Number of trials given to `study.optimize`, and number of loop iterations on the
            other processes.
        direction (`str` or `list`): Passed to `optuna.create_study` as `direction`, or as `directions` when
            it is a list.
        **kwargs: `timeout`, `n_jobs`, `gc_after_trial` and `catch` are popped and forwarded to
            `study.optimize`; everything else is forwarded to `optuna.create_study`.

    Returns:
        On process 0, a [`BestRun`] for a single-objective study or a list of [`BestRun`] for a
        multi-objective one. `None` on every other process.

    Raises:
        RuntimeError: If more than one process is involved and `trainer.args.parallel_mode` is not
            `ParallelMode.DISTRIBUTED`.
    """
    import optuna
    from accelerate.utils.memory import release_memory

    if trainer.args.process_index == 0:

        def _objective(trial: optuna.Trial, checkpoint_dir=None):
            # Use a checkpoint sub-directory of `checkpoint_dir` when one exists (the last one listed wins).
            checkpoint = None
            if checkpoint_dir:
                for subdir in os.listdir(checkpoint_dir):
                    if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                        checkpoint = os.path.join(checkpoint_dir, subdir)
            trainer.objective = None
            if trainer.args.world_size > 1:
                if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                    raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
                # Call `hp_space` on the live trial first, then freeze `trial.params` into a `FixedTrial`
                # and broadcast it to the other processes.
                trainer.hp_space(trial)
                fixed_trial = optuna.trial.FixedTrial(trial.params, trial.number)
                trial_main_rank_list = [fixed_trial]
                torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
                trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
            else:
                trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
            # If there hasn't been any evaluation during the training loop.
            if getattr(trainer, "objective", None) is None:
                metrics = trainer.evaluate()
                trainer.objective = trainer.compute_objective(metrics)

            # Free GPU memory
            trainer.model_wrapped, trainer.model = release_memory(trainer.model_wrapped, trainer.model)
            trainer.accelerator.clear()

            return trainer.objective

        # Split the kwargs between `study.optimize` (popped here) and `optuna.create_study` (the rest).
        timeout = kwargs.pop("timeout", None)
        n_jobs = kwargs.pop("n_jobs", 1)
        gc_after_trial = kwargs.pop("gc_after_trial", False)
        catch = kwargs.pop("catch", ())
        # A list of directions means a multi-objective study: pass it as `directions` and clear `direction`.
        directions = direction if isinstance(direction, list) else None
        direction = None if directions is not None else direction
        study = optuna.create_study(direction=direction, directions=directions, **kwargs)
        study.optimize(
            _objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs, gc_after_trial=gc_after_trial, catch=catch
        )
        if not study._is_multi_objective():
            best_trial = study.best_trial
            return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
        else:
            best_trials = study.best_trials
            return [BestRun(str(best.number), best.values, best.params) for best in best_trials]
    else:
        # One iteration per trial: each one receives the trial broadcast from process 0 into the
        # single-slot list and trains with it.
        for i in range(n_trials):
            trainer.objective = None
            trial_main_rank_list = [None]
            if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
            torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
            trainer.train(resume_from_checkpoint=None, trial=trial_main_rank_list[0])
            # If there hasn't been any evaluation during the training loop.
            if getattr(trainer, "objective", None) is None:
                metrics = trainer.evaluate()
                trainer.objective = trainer.compute_objective(metrics)
        return None
def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search with Ray Tune and return the best trial as a [`BestRun`].

    Args:
        trainer: The trainer to run. It is mutated: its memory tracker may be replaced, its
            `TensorBoardCallback` is removed for the duration of the search, `trainer.model` is set to `None`
            and `trainer.args._n_gpu` is overwritten.
        n_trials (`int`): Passed to `ray.tune.run` as `num_samples`.
        direction (`str`): Its first three characters are used as the `mode` when picking the best trial.
        **kwargs: Forwarded to `ray.tune.run`. `resources_per_trial` and `progress_reporter` get defaults
            when missing.

    Raises:
        RuntimeError: If `scheduler` is one of the schedulers checked below and evaluation during training
            is not enabled on `trainer.args`.

    Environment:
    - **RAY_SCOPE** (`str`, *optional*, defaults to `"last"`):
        The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray
        will then use the last checkpoint of all trials, compare those, and select the best one. However,
        other options are also available. See the Ray documentation (https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial)
        for more options
    """
    import ray.tune

    def _objective(trial: dict, local_trainer):
        # Swap a notebook progress callback, if one is registered, for the plain `ProgressCallback`.
        try:
            from transformers.utils.notebook import NotebookProgressCallback

            if local_trainer.pop_callback(NotebookProgressCallback):
                local_trainer.add_callback(ProgressCallback)
        except ModuleNotFoundError:
            pass

        local_trainer.objective = None

        checkpoint = ray.tune.get_checkpoint()
        if checkpoint:
            # Upon trial resume, the local_trainer's objective gets reset to None.
            # If `local_trainer.train` is a noop (training has already reached
            # the target number of epochs/steps), then this would
            # trigger an unnecessary extra checkpoint at the end of training.
            # -> Set the objective to a dummy value upon resume as a workaround.
            local_trainer.objective = "objective"

            with checkpoint.as_directory() as checkpoint_dir:
                checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
                local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
        else:
            local_trainer.train(trial=trial)

        # If there hasn't been any evaluation during the training loop.
        if getattr(local_trainer, "objective", None) is None:
            metrics = local_trainer.evaluate()
            local_trainer.objective = local_trainer.compute_objective(metrics)

            metrics.update({"objective": local_trainer.objective, "done": True})

            # Save a final checkpoint into a temporary directory and report it together with the metrics.
            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
                checkpoint = ray.tune.Checkpoint.from_directory(temp_checkpoint_dir)
                ray.tune.report(metrics, checkpoint=checkpoint)

    if not trainer._memory_tracker.skip_memory_metrics:
        from ..trainer_utils import TrainerMemoryTracker

        logger.warning(
            "Memory tracking for your Trainer is currently "
            "enabled. Automatically disabling the memory tracker "
            "since the memory tracker is not serializable."
        )
        trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)

    # The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
    # while doing the ray hp search.
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None

    # Setup default `resources_per_trial`.
    if "resources_per_trial" not in kwargs:
        # Default to 1 CPU and 1 GPU (if applicable) per trial.
        kwargs["resources_per_trial"] = {"cpu": 1}
        if trainer.args.n_gpu > 0:
            kwargs["resources_per_trial"]["gpu"] = 1
        resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
        logger.info(
            "No `resources_per_trial` arg was passed into "
            "`hyperparameter_search`. Setting it to a default value "
            f"of {resource_msg} for each trial."
        )
    # Make sure each trainer only uses GPUs that were allocated per trial.
    gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
    trainer.args._n_gpu = gpus_per_trial

    # Setup default `progress_reporter`.
    if "progress_reporter" not in kwargs:
        from ray.tune import CLIReporter

        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])

    if "scheduler" in kwargs:
        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

        # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
        if isinstance(
            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
        ) and (not trainer.args.do_eval or trainer.args.eval_strategy == IntervalStrategy.NO):
            raise RuntimeError(
                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                "This means your trials will not report intermediate results to Ray Tune, and "
                "can thus not be stopped early or used to exploit other trials parameters. "
                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                "make sure you pass `do_eval=True` and `eval_strategy='steps'` in the "
                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
            )

    # Bind the (already stripped) trainer as the `local_trainer` argument of `_objective`.
    trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)

    @functools.wraps(trainable)
    def dynamic_modules_import_trainable(*args, **kwargs):
        """
        Wrapper around `tune.with_parameters` to ensure datasets_modules are loaded on each Actor.

        Without this, an ImportError will be thrown. See https://github.com/huggingface/transformers/issues/11565.

        Assumes that `_objective`, defined above, is a function.
        """
        # Only done for `datasets` versions older than 4.0.0.
        if is_datasets_available() and packaging.version.parse(
            importlib.metadata.version("datasets")
        ) < packaging.version.parse("4.0.0"):
            import datasets.load

            dynamic_modules_path = os.path.join(datasets.load.init_dynamic_modules(), "__init__.py")
            # load dynamic_modules from path
            spec = importlib.util.spec_from_file_location("datasets_modules", dynamic_modules_path)
            datasets_modules = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = datasets_modules
            spec.loader.exec_module(datasets_modules)
        return trainable(*args, **kwargs)

    # special attr set by tune.with_parameters
    if hasattr(trainable, "__mixins__"):
        dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__

    analysis = ray.tune.run(
        dynamic_modules_import_trainable,
        config=trainer.hp_space(None),
        num_samples=n_trials,
        **kwargs,
    )
    ray_scope = os.getenv("RAY_SCOPE", "last")
    # `direction[:3]` turns e.g. "minimize"/"maximize" into the "min"/"max" mode string.
    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3], scope=ray_scope)
    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config, analysis)
    # Re-register the TensorBoard callback that was removed before the search.
    # NOTE(review): `trainer.model` is left as `None` here, and the callback is not restored if the search
    # raises — confirm callers handle this.
    if _tb_writer is not None:
        trainer.add_callback(_tb_writer)
    return best_run
def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search as a Weights & Biases sweep and return the best run as a [`BestRun`].

    Args:
        trainer: The trainer to run. A `WandbCallback` is added when none is registered and
            `trainer.args.report_to` is overwritten with `["wandb"]`.
        n_trials (`int`): Passed to `wandb.agent` as `count`.
        direction (`str`): Stored as the sweep metric `goal`; `"minimize"` / `"maximize"` also decide how
            runs are compared when tracking the best one.
        **kwargs: `sweep_id`, `project`, `name`, `entity` and `metric` (defaults to `"eval/loss"`) are read
            from here.

    Raises:
        ImportError: If wandb is not available.
    """
    if not is_wandb_available():
        raise ImportError("This function needs wandb installed: `pip install wandb`")
    import wandb

    # add WandbCallback if not already added in trainer callbacks
    reporting_to_wandb = False
    for callback in trainer.callback_handler.callbacks:
        if isinstance(callback, WandbCallback):
            reporting_to_wandb = True
            break
    if not reporting_to_wandb:
        trainer.add_callback(WandbCallback())
    trainer.args.report_to = ["wandb"]
    # Shared state updated by `_objective` after every run; holds the best run seen so far.
    best_trial = {"run_id": None, "objective": None, "hyperparameters": None}
    sweep_id = kwargs.pop("sweep_id", None)
    project = kwargs.pop("project", None)
    name = kwargs.pop("name", None)
    entity = kwargs.pop("entity", None)
    metric = kwargs.pop("metric", "eval/loss")

    # The sweep configuration comes from the trainer's search space; it must already contain a "metric"
    # entry, which is filled in here.
    sweep_config = trainer.hp_space(None)
    sweep_config["metric"]["goal"] = direction
    sweep_config["metric"]["name"] = metric
    if name:
        sweep_config["name"] = name

    def _objective():
        # Reuse the active run when there is one, otherwise start a new one.
        run = wandb.run if wandb.run else wandb.init()
        trainer.state.trial_name = run.name
        run.config.update({"assignments": {}, "metric": metric})
        config = wandb.config

        trainer.objective = None

        # The `_items` entry of the config object's `__dict__` is handed to the trainer as the trial.
        trainer.train(resume_from_checkpoint=None, trial=vars(config)["_items"])
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
            format_metrics = rewrite_logs(metrics)
            if metric not in format_metrics:
                logger.warning(
                    f"Provided metric {metric} not found. This might result in unexpected sweeps charts. The available"
                    f" metrics are {format_metrics.keys()}"
                )
        # Compare against the best run so far; the first run always becomes the best one.
        best_score = False
        if best_trial["run_id"] is not None:
            if direction == "minimize":
                best_score = trainer.objective < best_trial["objective"]
            elif direction == "maximize":
                best_score = trainer.objective > best_trial["objective"]

        if best_score or best_trial["run_id"] is None:
            best_trial["run_id"] = run.id
            best_trial["objective"] = trainer.objective
            best_trial["hyperparameters"] = dict(config)

        return trainer.objective

    if not sweep_id:
        # No existing sweep given: create one from the configuration.
        sweep_id = wandb.sweep(sweep_config, project=project, entity=entity)
    else:
        # Existing sweep: set the entity (when given) and the project through `wandb.env`.
        import wandb.env

        if entity:
            wandb.env.set_entity(entity)
        wandb.env.set_project(project)
    logger.info(f"wandb sweep id - {sweep_id}")
    wandb.agent(sweep_id, function=_objective, count=n_trials)

    return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"], sweep_id)
  410. def get_available_reporting_integrations():
  411. integrations = []
  412. if is_azureml_available() and not is_mlflow_available():
  413. integrations.append("azure_ml")
  414. if is_comet_available():
  415. integrations.append("comet_ml")
  416. if is_dagshub_available():
  417. integrations.append("dagshub")
  418. if is_dvclive_available():
  419. integrations.append("dvclive")
  420. if is_mlflow_available():
  421. integrations.append("mlflow")
  422. if is_neptune_available():
  423. integrations.append("neptune")
  424. if is_tensorboard_available():
  425. integrations.append("tensorboard")
  426. if is_wandb_available():
  427. integrations.append("wandb")
  428. if is_codecarbon_available():
  429. integrations.append("codecarbon")
  430. if is_clearml_available():
  431. integrations.append("clearml")
  432. if is_swanlab_available():
  433. integrations.append("swanlab")
  434. if is_trackio_available():
  435. integrations.append("trackio")
  436. if is_kubeflow_available():
  437. integrations.append("kubeflow")
  438. return integrations
  439. def rewrite_logs(d):
  440. new_d = {}
  441. eval_prefix = "eval_"
  442. eval_prefix_len = len(eval_prefix)
  443. test_prefix = "test_"
  444. test_prefix_len = len(test_prefix)
  445. for k, v in d.items():
  446. if k.startswith(eval_prefix):
  447. new_d["eval/" + k[eval_prefix_len:]] = v
  448. elif k.startswith(test_prefix):
  449. new_d["test/" + k[test_prefix_len:]] = v
  450. else:
  451. new_d["train/" + k] = v
  452. return new_d
  453. def default_logdir() -> str:
  454. """
  455. Same default as PyTorch
  456. """
  457. import socket
  458. from datetime import datetime
  459. current_time = datetime.now().strftime("%b%d_%H-%M-%S")
  460. return os.path.join("runs", current_time + "_" + socket.gethostname())
  461. class TensorBoardCallback(TrainerCallback):
  462. """
  463. A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).
  464. Args:
  465. tb_writer (`SummaryWriter`, *optional*):
  466. The writer to use. Will instantiate one if not set.
  467. Environment:
  468. - **TENSORBOARD_LOGGING_DIR** (`str`, *optional*, defaults to `None`):
  469. The logging dir to log the results. Default value is os.path.join(args.output_dir, default_logdir())
  470. """
  471. def __init__(self, tb_writer=None):
  472. if not is_tensorboard_available():
  473. raise RuntimeError(
  474. "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
  475. " install tensorboardX."
  476. )
  477. try:
  478. from torch.utils.tensorboard import SummaryWriter
  479. except ImportError:
  480. from tensorboardX import SummaryWriter
  481. self._SummaryWriter = SummaryWriter
  482. self.tb_writer = tb_writer
  483. self.logging_dir = os.getenv("TENSORBOARD_LOGGING_DIR", None)
  484. if self.logging_dir is not None:
  485. self.logging_dir = os.path.expanduser(self.logging_dir)
  486. def _init_summary_writer(self, args):
  487. if self._SummaryWriter is not None:
  488. self.tb_writer = self._SummaryWriter(log_dir=self.logging_dir)
  489. def on_train_begin(self, args, state, control, **kwargs):
  490. if not state.is_world_process_zero:
  491. return
  492. if state.is_hyper_param_search:
  493. trial_name = state.trial_name
  494. if trial_name is not None:
  495. # overwrite logging dir for trials
  496. self.logging_dir = os.path.join(args.output_dir, default_logdir(), trial_name)
  497. if self.logging_dir is None:
  498. self.logging_dir = os.path.join(args.output_dir, default_logdir())
  499. if self.tb_writer is None:
  500. self._init_summary_writer(args)
  501. if self.tb_writer is not None:
  502. self.tb_writer.add_text("args", args.to_json_string())
  503. if "model" in kwargs:
  504. model = kwargs["model"]
  505. if hasattr(model, "config") and model.config is not None:
  506. model_config_json = model.config.to_json_string()
  507. self.tb_writer.add_text("model_config", model_config_json)
  508. def on_log(self, args, state, control, logs=None, **kwargs):
  509. if not state.is_world_process_zero:
  510. return
  511. if self.tb_writer is None:
  512. self._init_summary_writer(args)
  513. if self.tb_writer is not None:
  514. logs = rewrite_logs(logs)
  515. for k, v in logs.items():
  516. if isinstance(v, (int, float)):
  517. self.tb_writer.add_scalar(k, v, state.global_step)
  518. elif isinstance(v, str):
  519. self.tb_writer.add_text(k, v, state.global_step)
  520. else:
  521. logger.warning(
  522. "Trainer is attempting to log a value of "
  523. f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
  524. "This invocation of Tensorboard's writer.add_scalar() "
  525. "is incorrect so we dropped this attribute."
  526. )
  527. self.tb_writer.flush()
  528. def on_train_end(self, args, state, control, **kwargs):
  529. if self.tb_writer:
  530. self.tb_writer.close()
  531. self.tb_writer = None
  532. def save_model_architecture_to_file(model: Any, output_dir: str):
  533. with open(f"{output_dir}/model_architecture.txt", "w+") as f:
  534. if isinstance(model, PreTrainedModel):
  535. print(model, file=f)
  536. elif is_torch_available() and (
  537. isinstance(model, (torch.nn.Module, PushToHubMixin)) and hasattr(model, "base_model")
  538. ):
  539. print(model, file=f)
  540. class WandbLogModel(str, Enum):
  541. """Enum of possible log model values in W&B."""
  542. CHECKPOINT = "checkpoint"
  543. END = "end"
  544. FALSE = "false"
  545. @property
  546. def is_enabled(self) -> bool:
  547. """Check if the value corresponds to a state where the `WANDB_LOG_MODEL` setting is enabled."""
  548. return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)
  549. @classmethod
  550. def _missing_(cls, value: Any) -> "WandbLogModel":
  551. if not isinstance(value, str):
  552. raise TypeError(f"Expecting to have a string `WANDB_LOG_MODEL` setting, but got {type(value)}")
  553. logger.warning(
  554. f"Received unrecognized `WANDB_LOG_MODEL` setting value={value}; so disabling `WANDB_LOG_MODEL`"
  555. )
  556. return WandbLogModel.FALSE
  557. class WandbCallback(TrainerCallback):
  558. """
  559. A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
  560. """
  561. def __init__(self):
  562. has_wandb = is_wandb_available()
  563. if not has_wandb:
  564. raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
  565. import wandb
  566. self._wandb = wandb
  567. self._initialized = False
  568. self._log_model = WandbLogModel(os.getenv("WANDB_LOG_MODEL", "false"))
  569. def setup(self, args, state, model, **kwargs):
  570. """
  571. Setup the optional Weights & Biases (*wandb*) integration.
  572. One can subclass and override this method to customize the setup if needed. Find more information
  573. [here](https://docs.wandb.ai/guides/integrations/huggingface). You can also override the following environment
  574. variables:
  575. Environment:
  576. - **WANDB_LOG_MODEL** (`str`, *optional*, defaults to `"false"`):
  577. Whether to log model and checkpoints during training. Can be `"end"`, `"checkpoint"` or `"false"`. If set
  578. to `"end"`, the model will be uploaded at the end of training. If set to `"checkpoint"`, the checkpoint
  579. will be uploaded every `args.save_steps` . If set to `"false"`, the model will not be uploaded. Use along
  580. with [`~transformers.TrainingArguments.load_best_model_at_end`] to upload best model.
  581. - **WANDB_WATCH** (`str`, *optional* defaults to `"false"`):
  582. Can be `"gradients"`, `"all"`, `"parameters"`, or `"false"`. Set to `"all"` to log gradients and
  583. parameters.
  584. - **WANDB_PROJECT** (`str`, *optional*, defaults to `"huggingface"`):
  585. Set this to a custom string to store results in a different project.
  586. """
  587. if self._wandb is None:
  588. return
  589. self._initialized = True
  590. # prepare to handle potential configuration issues during setup
  591. from wandb.sdk.lib.config_util import ConfigError as WandbConfigError
  592. if state.is_world_process_zero:
  593. combined_dict = {**args.to_dict()}
  594. if hasattr(model, "config") and model.config is not None:
  595. model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
  596. combined_dict = {**model_config, **combined_dict}
  597. if hasattr(model, "peft_config") and model.peft_config is not None:
  598. peft_config = model.peft_config
  599. combined_dict = {"peft_config": peft_config, **combined_dict}
  600. trial_name = state.trial_name
  601. init_args = {}
  602. if trial_name is not None:
  603. init_args["name"] = trial_name
  604. init_args["group"] = args.run_name or args.output_dir
  605. elif args.run_name is not None:
  606. init_args["name"] = args.run_name
  607. if args.run_name == args.output_dir:
  608. self._wandb.termwarn(
  609. "The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was "
  610. "not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.",
  611. repeat=False,
  612. )
  613. if self._wandb.run is None:
  614. self._wandb.init(
  615. project=os.getenv("WANDB_PROJECT", "huggingface"),
  616. **init_args,
  617. )
  618. # add config parameters (run may have been created manually)
  619. self._wandb.config.update(combined_dict or {}, allow_val_change=True)
  620. # define default x-axis (for latest wandb versions)
  621. if getattr(self._wandb, "define_metric", None):
  622. self._wandb.define_metric("train/global_step")
  623. self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)
  624. # keep track of model topology and gradients, unsupported on TPU
  625. _watch_model = os.getenv("WANDB_WATCH", "false")
  626. if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"):
  627. self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
  628. self._wandb.run._label(code="transformers_trainer")
  629. # add number of model parameters to wandb config
  630. try:
  631. self._wandb.config["model/num_parameters"] = model.num_parameters()
  632. except AttributeError:
  633. logger.info(
  634. "Could not log the number of model parameters in Weights & Biases due to an AttributeError."
  635. )
  636. except WandbConfigError:
  637. logger.warning(
  638. "A ConfigError was raised whilst setting the number of model parameters in Weights & Biases config."
  639. )
  640. # log the initial model architecture to an artifact
  641. if self._log_model.is_enabled:
  642. with tempfile.TemporaryDirectory() as temp_dir:
  643. model_name = (
  644. f"model-{self._wandb.run.id}"
  645. if (args.run_name is None or args.run_name == args.output_dir)
  646. else f"model-{self._wandb.run.name}"
  647. )
  648. model_artifact = self._wandb.Artifact(
  649. name=model_name,
  650. type="model",
  651. metadata={
  652. "model_config": model.config.to_dict() if hasattr(model, "config") else None,
  653. "num_parameters": self._wandb.config.get("model/num_parameters"),
  654. "initial_model": True,
  655. },
  656. )
  657. # add the architecture to a separate text file
  658. save_model_architecture_to_file(model, temp_dir)
  659. for f in Path(temp_dir).glob("*"):
  660. if f.is_file():
  661. with model_artifact.new_file(f.name, mode="wb") as fa:
  662. fa.write(f.read_bytes())
  663. self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
  664. badge_markdown = (
  665. f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
  666. f'-28.svg" alt="Visualize in Weights & Biases" width="20'
  667. f'0" height="32"/>]({self._wandb.run.url})'
  668. )
  669. modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
  670. def on_train_begin(self, args, state, control, model=None, **kwargs):
  671. if self._wandb is None:
  672. return
  673. hp_search = state.is_hyper_param_search
  674. if hp_search:
  675. self._wandb.finish()
  676. self._initialized = False
  677. args.run_name = None
  678. if not self._initialized:
  679. self.setup(args, state, model, **kwargs)
  680. def on_train_end(self, args: TrainingArguments, state, control, model=None, processing_class=None, **kwargs):
  681. if self._wandb is None:
  682. return
  683. if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
  684. from ..trainer import Trainer
  685. args_for_fake = copy.deepcopy(args)
  686. args_for_fake.deepspeed = None
  687. args_for_fake.deepspeed_plugin = None
  688. fake_trainer = Trainer(
  689. args=args_for_fake, model=model, processing_class=processing_class, eval_dataset=["fake"]
  690. )
  691. with tempfile.TemporaryDirectory() as temp_dir:
  692. fake_trainer.save_model(temp_dir)
  693. metadata = (
  694. {
  695. k: v
  696. for k, v in dict(self._wandb.summary).items()
  697. if isinstance(v, numbers.Number) and not k.startswith("_")
  698. }
  699. if not args.load_best_model_at_end
  700. else {
  701. f"eval/{args.metric_for_best_model}": state.best_metric,
  702. "train/total_floss": state.total_flos,
  703. "model/num_parameters": self._wandb.config.get("model/num_parameters"),
  704. }
  705. )
  706. metadata["final_model"] = True
  707. logger.info("Logging model artifacts. ...")
  708. model_name = (
  709. f"model-{self._wandb.run.id}"
  710. if (args.run_name is None or args.run_name == args.output_dir)
  711. else f"model-{self._wandb.run.name}"
  712. )
  713. # add the model architecture to a separate text file
  714. save_model_architecture_to_file(model, temp_dir)
  715. artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
  716. for f in Path(temp_dir).glob("*"):
  717. if f.is_file():
  718. with artifact.new_file(f.name, mode="wb") as fa:
  719. fa.write(f.read_bytes())
  720. self._wandb.run.log_artifact(artifact, aliases=["final_model"])
  721. def on_log(self, args, state, control, model=None, logs=None, **kwargs):
  722. single_value_scalars = [
  723. "train_runtime",
  724. "train_samples_per_second",
  725. "train_steps_per_second",
  726. "train_loss",
  727. "total_flos",
  728. ]
  729. if self._wandb is None:
  730. return
  731. if not self._initialized:
  732. self.setup(args, state, model)
  733. if state.is_world_process_zero:
  734. for k, v in logs.items():
  735. if k in single_value_scalars:
  736. self._wandb.run.summary[k] = v
  737. non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
  738. non_scalar_logs = rewrite_logs(non_scalar_logs)
  739. self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})
  740. def on_save(self, args, state, control, **kwargs):
  741. if self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero:
  742. checkpoint_metadata = {
  743. k: v
  744. for k, v in dict(self._wandb.summary).items()
  745. if isinstance(v, numbers.Number) and not k.startswith("_")
  746. }
  747. checkpoint_metadata["model/num_parameters"] = self._wandb.config.get("model/num_parameters")
  748. ckpt_dir = f"checkpoint-{state.global_step}"
  749. artifact_path = os.path.join(args.output_dir, ckpt_dir)
  750. logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
  751. checkpoint_name = (
  752. f"model-{self._wandb.run.id}"
  753. if (args.run_name is None or args.run_name == args.output_dir)
  754. else f"model-{self._wandb.run.name}"
  755. )
  756. artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
  757. artifact.add_dir(artifact_path)
  758. self._wandb.log_artifact(
  759. artifact, aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"]
  760. )
  761. def on_predict(self, args, state, control, metrics, **kwargs):
  762. if self._wandb is None:
  763. return
  764. if not self._initialized:
  765. self.setup(args, state, **kwargs)
  766. if state.is_world_process_zero:
  767. metrics = rewrite_logs(metrics)
  768. self._wandb.log(metrics)
  769. class TrackioCallback(TrainerCallback):
  770. """
  771. A [`TrainerCallback`] that logs metrics to Trackio.
  772. It records training metrics, model (including PEFT) configuration.
  773. **Requires**:
  774. ```bash
  775. pip install trackio
  776. ```
  777. """
  778. SPACE_URL = "https://huggingface.co/spaces/{space_id}"
  779. def __init__(self):
  780. has_trackio = is_trackio_available()
  781. if not has_trackio:
  782. raise RuntimeError("TrackioCallback requires trackio to be installed. Run `pip install trackio`.")
  783. if has_trackio:
  784. import trackio
  785. self._trackio = trackio
  786. self._initialized = False
  787. def setup(self, args, state, model, **kwargs):
  788. """
  789. Setup the optional Trackio integration.
  790. To customize the setup you can also set the arguments `project`, `trackio_space_id` and `hub_private_repo` in
  791. [`TrainingArguments`]. Please refer to the docstring of for more details.
  792. """
  793. if state.is_world_process_zero:
  794. combined_dict = {**args.to_dict()}
  795. if hasattr(model, "config") and model.config is not None:
  796. model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
  797. combined_dict = {**model_config, **combined_dict}
  798. if hasattr(model, "peft_config") and model.peft_config is not None:
  799. peft_config = model.peft_config
  800. combined_dict = {"peft_config": peft_config, **combined_dict}
  801. self._trackio.init(
  802. project=args.project,
  803. name=args.run_name,
  804. space_id=args.trackio_space_id,
  805. resume="allow",
  806. private=args.hub_private_repo,
  807. )
  808. # Add config parameters (run may have been created manually)
  809. self._trackio.config.update(combined_dict, allow_val_change=True)
  810. # Add number of model parameters to trackio config
  811. try:
  812. self._trackio.config["model/num_parameters"] = model.num_parameters()
  813. except AttributeError:
  814. logger.info("Could not log the number of model parameters in Trackio due to an AttributeError.")
  815. self._initialized = True
  816. def on_train_begin(self, args, state, control, model=None, **kwargs):
  817. if not self._initialized:
  818. self.setup(args, state, model, **kwargs)
  819. def on_train_end(self, args: TrainingArguments, state, control, model=None, processing_class=None, **kwargs):
  820. if state.is_world_process_zero and self._initialized:
  821. self._trackio.finish()
  822. def on_log(self, args, state, control, model=None, logs=None, **kwargs):
  823. single_value_scalars = [
  824. "train_runtime",
  825. "train_samples_per_second",
  826. "train_steps_per_second",
  827. "train_loss",
  828. "total_flos",
  829. ]
  830. if not self._initialized:
  831. self.setup(args, state, model)
  832. if state.is_world_process_zero:
  833. non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
  834. non_scalar_logs = rewrite_logs(non_scalar_logs)
  835. self._trackio.log({**non_scalar_logs, "train/global_step": state.global_step})
  836. def on_save(self, args, state, control, **kwargs):
  837. return
  838. def on_predict(self, args, state, control, metrics, **kwargs):
  839. if self._trackio is None:
  840. return
  841. if not self._initialized:
  842. self.setup(args, state, **kwargs)
  843. if state.is_world_process_zero:
  844. metrics = rewrite_logs(metrics)
  845. self._trackio.log(metrics)
  846. def on_push_begin(self, args, state, control, model, **kwargs):
  847. if not state.is_world_process_zero or self._trackio is None:
  848. return
  849. if (current_project := self._trackio.context_vars.current_project.get()) is None:
  850. return
  851. trackio_version = packaging.version.parse(self._trackio.__version__)
  852. if trackio_version < packaging.version.parse("0.13.0"):
  853. warnings.warn(
  854. "The version of `trackio` that is installed is <=0.13.0, so "
  855. "the local Trackio project will not be pushed to Hugging Face. Run "
  856. "`pip install --upgrade trackio` to fix this."
  857. )
  858. return
  859. space_id = self._trackio.context_vars.current_space_id.get()
  860. if space_id is None:
  861. space_id = self._trackio.sync(current_project, force=True)
  862. space_url = self.SPACE_URL.format(space_id=space_id)
  863. badge_markdown = (
  864. f'<a href="{space_url}" target="_blank"><img src="https://raw.githubusercontent.com/gradio-app/trackio/refs/heads/main/trackio/assets/badge.png" alt="Visualize in Trackio"'
  865. ' title="Visualize in Trackio" style="height: 40px;"/></a>'
  866. )
  867. if badge_markdown not in modelcard.AUTOGENERATED_TRAINER_COMMENT:
  868. modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
  869. trackio_tags = ["trackio", f"trackio:{space_url}"]
  870. if getattr(model, "model_tags", None) is not None:
  871. if "trackio" not in model.model_tags:
  872. model.model_tags.extend(trackio_tags)
  873. else:
  874. model.model_tags = trackio_tags
  875. class CometCallback(TrainerCallback):
  876. """
  877. A [`TrainerCallback`] that sends the logs to [Comet ML](https://www.comet.com/site/).
  878. """
  879. def __init__(self):
  880. if _is_comet_installed is False or _is_comet_recent_enough is False:
  881. raise RuntimeError(
  882. f"CometCallback requires comet-ml>={_MIN_COMET_VERSION} to be installed. Run `pip install comet-ml>={_MIN_COMET_VERSION}`."
  883. )
  884. self._initialized = False
  885. self._log_assets = False
  886. self._experiment = None
  887. def setup(self, args, state, model):
  888. """
  889. Setup the optional Comet integration.
  890. Environment:
  891. - **COMET_MODE** (`str`, *optional*, default to `get_or_create`):
  892. Control whether to create and log to a new Comet experiment or append to an existing experiment.
  893. It accepts the following values:
  894. * `get_or_create`: Decides automatically depending if
  895. `COMET_EXPERIMENT_KEY` is set and whether an Experiment
  896. with that key already exists or not.
  897. * `create`: Always create a new Comet Experiment.
  898. * `get`: Always try to append to an Existing Comet Experiment.
  899. Requires `COMET_EXPERIMENT_KEY` to be set.
  900. - **COMET_START_ONLINE** (`bool`, *optional*):
  901. Whether to create an online or offline Experiment.
  902. - **COMET_PROJECT_NAME** (`str`, *optional*):
  903. Comet project name for experiments.
  904. - **COMET_LOG_ASSETS** (`str`, *optional*, defaults to `TRUE`):
  905. Whether or not to log training assets (checkpoints, etc), to Comet. Can be `TRUE`, or
  906. `FALSE`.
  907. For a number of configurable items in the environment, see
  908. [here](https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options).
  909. """
  910. self._initialized = True
  911. log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper()
  912. if log_assets in {"TRUE", "1"}:
  913. self._log_assets = True
  914. if state.is_world_process_zero:
  915. comet_old_mode = os.getenv("COMET_MODE")
  916. mode = None
  917. online = None
  918. if comet_old_mode is not None:
  919. comet_old_mode = comet_old_mode.lower()
  920. if comet_old_mode in ("get", "get_or_create", "create"):
  921. mode = comet_old_mode
  922. elif comet_old_mode:
  923. logger.warning("Invalid COMET_MODE env value %r, Comet logging is disabled", comet_old_mode)
  924. return
  925. # For HPO, we always create a new experiment for each trial
  926. if state.is_hyper_param_search:
  927. if mode is not None:
  928. logger.warning(
  929. "Hyperparameter Search is enabled, forcing the creation of new experiments, COMET_MODE value %r is ignored",
  930. comet_old_mode,
  931. )
  932. mode = "create"
  933. import comet_ml
  934. experiment_config = comet_ml.ExperimentConfig(name=args.run_name)
  935. self._experiment = comet_ml.start(online=online, mode=mode, experiment_config=experiment_config)
  936. self._experiment.__internal_api__set_model_graph__(model, framework="transformers")
  937. params = {"args": args.to_dict()}
  938. if hasattr(model, "config") and model.config is not None:
  939. model_config = model.config.to_dict()
  940. params["config"] = model_config
  941. if hasattr(model, "peft_config") and model.peft_config is not None:
  942. peft_config = model.peft_config
  943. params["peft_config"] = peft_config
  944. self._experiment.__internal_api__log_parameters__(
  945. params, framework="transformers", source="manual", flatten_nested=True
  946. )
  947. if state.is_hyper_param_search:
  948. optimization_id = getattr(state, "trial_name", None)
  949. optimization_params = getattr(state, "trial_params", None)
  950. self._experiment.log_optimization(optimization_id=optimization_id, parameters=optimization_params)
  951. def on_train_begin(self, args, state, control, model=None, **kwargs):
  952. if not self._initialized:
  953. self.setup(args, state, model)
  954. def on_log(self, args, state, control, model=None, logs=None, **kwargs):
  955. if not self._initialized:
  956. self.setup(args, state, model)
  957. if state.is_world_process_zero:
  958. if self._experiment is not None:
  959. rewritten_logs = rewrite_logs(logs)
  960. self._experiment.__internal_api__log_metrics__(
  961. rewritten_logs, step=state.global_step, epoch=state.epoch, framework="transformers"
  962. )
  963. def on_train_end(self, args, state, control, **kwargs):
  964. if self._initialized and state.is_world_process_zero:
  965. if self._experiment is not None:
  966. if self._log_assets is True:
  967. logger.info("Logging checkpoints. This may take time.")
  968. self._experiment.log_asset_folder(
  969. args.output_dir, recursive=True, log_file_name=True, step=state.global_step
  970. )
  971. # We create one experiment per trial in HPO mode
  972. if state.is_hyper_param_search:
  973. self._experiment.clean()
  974. self._initialized = False
  975. def on_predict(self, args, state, control, metrics, **kwargs):
  976. if not self._initialized:
  977. self.setup(args, state, model=None)
  978. if state.is_world_process_zero and self._experiment is not None:
  979. rewritten_metrics = rewrite_logs(metrics)
  980. self._experiment.__internal_api__log_metrics__(
  981. rewritten_metrics, step=state.global_step, epoch=state.epoch, framework="transformers"
  982. )
  983. class AzureMLCallback(TrainerCallback):
  984. """
  985. A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/).
  986. """
  987. def __init__(self, azureml_run=None):
  988. if not is_azureml_available():
  989. raise RuntimeError("AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`.")
  990. self.azureml_run = azureml_run
  991. def on_init_end(self, args, state, control, **kwargs):
  992. from azureml.core.run import Run
  993. if self.azureml_run is None and state.is_world_process_zero:
  994. self.azureml_run = Run.get_context()
  995. def on_log(self, args, state, control, logs=None, **kwargs):
  996. if self.azureml_run and state.is_world_process_zero:
  997. for k, v in logs.items():
  998. if isinstance(v, (int, float)):
  999. self.azureml_run.log(k, v, description=k)
  1000. class MLflowCallback(TrainerCallback):
  1001. """
  1002. A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). Can be disabled by setting
  1003. environment variable `DISABLE_MLFLOW_INTEGRATION = TRUE`.
  1004. """
  1005. def __init__(self):
  1006. if not is_mlflow_available():
  1007. raise RuntimeError("MLflowCallback requires mlflow to be installed. Run `pip install mlflow`.")
  1008. import mlflow
  1009. self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
  1010. self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
  1011. self._initialized = False
  1012. self._auto_end_run = False
  1013. self._log_artifacts = False
  1014. self._ml_flow = mlflow
  1015. def setup(self, args, state, model):
  1016. """
  1017. Setup the optional MLflow integration.
  1018. Environment:
  1019. - **HF_MLFLOW_LOG_ARTIFACTS** (`str`, *optional*):
  1020. Whether to use MLflow `.log_artifact()` facility to log artifacts. This only makes sense if logging to a
  1021. remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy each saved checkpoint on each save in
  1022. [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
  1023. storage will just copy the files to your artifact location.
  1024. - **MLFLOW_TRACKING_URI** (`str`, *optional*):
  1025. Whether to store runs at a specific path or remote server. Unset by default, which skips setting the
  1026. tracking URI entirely.
  1027. - **MLFLOW_EXPERIMENT_NAME** (`str`, *optional*, defaults to `None`):
  1028. Whether to use an MLflow experiment_name under which to launch the run. Default to `None` which will point
  1029. to the `Default` experiment in MLflow. Otherwise, it is a case sensitive name of the experiment to be
  1030. activated. If an experiment with this name does not exist, a new experiment with this name is created.
  1031. - **MLFLOW_TAGS** (`str`, *optional*):
  1032. A string dump of a dictionary of key/value pair to be added to the MLflow run as tags. Example:
  1033. `os.environ['MLFLOW_TAGS']='{"release.candidate": "RC1", "release.version": "2.2.0"}'`.
  1034. - **MLFLOW_NESTED_RUN** (`str`, *optional*):
  1035. Whether to use MLflow nested runs. If set to `True` or *1*, will create a nested run inside the current
  1036. run.
  1037. - **MLFLOW_RUN_ID** (`str`, *optional*):
  1038. Allow to reattach to an existing run which can be useful when resuming training from a checkpoint. When
  1039. `MLFLOW_RUN_ID` environment variable is set, `start_run` attempts to resume a run with the specified run ID
  1040. and other parameters are ignored.
  1041. - **MLFLOW_FLATTEN_PARAMS** (`str`, *optional*, defaults to `False`):
  1042. Whether to flatten the parameters dictionary before logging.
  1043. - **MLFLOW_MAX_LOG_PARAMS** (`int`, *optional*):
  1044. Set the maximum number of parameters to log in the run.
  1045. """
  1046. self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
  1047. self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
  1048. self._tracking_uri = os.getenv("MLFLOW_TRACKING_URI", None)
  1049. self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
  1050. self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
  1051. self._run_id = os.getenv("MLFLOW_RUN_ID", None)
  1052. self._max_log_params = os.getenv("MLFLOW_MAX_LOG_PARAMS", None)
  1053. # "synchronous" flag is only available with mlflow version >= 2.8.0
  1054. # https://github.com/mlflow/mlflow/pull/9705
  1055. # https://github.com/mlflow/mlflow/releases/tag/v2.8.0
  1056. self._async_log = packaging.version.parse(self._ml_flow.__version__) >= packaging.version.parse("2.8.0")
  1057. logger.debug(
  1058. f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run},"
  1059. f" tracking_uri={self._tracking_uri}"
  1060. )
  1061. if state.is_world_process_zero:
  1062. if not self._ml_flow.is_tracking_uri_set():
  1063. if self._tracking_uri:
  1064. self._ml_flow.set_tracking_uri(self._tracking_uri)
  1065. logger.debug(f"MLflow tracking URI is set to {self._tracking_uri}")
  1066. else:
  1067. logger.debug(
  1068. "Environment variable `MLFLOW_TRACKING_URI` is not provided and therefore will not be"
  1069. " explicitly set."
  1070. )
  1071. else:
  1072. logger.debug(f"MLflow tracking URI is set to {self._ml_flow.get_tracking_uri()}")
  1073. if self._ml_flow.active_run() is None or self._nested_run or self._run_id:
  1074. if self._experiment_name:
  1075. # Use of set_experiment() ensure that Experiment is created if not exists
  1076. self._ml_flow.set_experiment(self._experiment_name)
  1077. self._ml_flow.start_run(run_name=args.run_name, nested=self._nested_run)
  1078. logger.debug(f"MLflow run started with run_id={self._ml_flow.active_run().info.run_id}")
  1079. self._auto_end_run = True
  1080. combined_dict = args.to_dict()
  1081. if hasattr(model, "config") and model.config is not None:
  1082. model_config = model.config.to_dict()
  1083. combined_dict = {**model_config, **combined_dict}
  1084. combined_dict = flatten_dict(combined_dict) if self._flatten_params else combined_dict
  1085. # remove params that are too long for MLflow
  1086. for name, value in list(combined_dict.items()):
  1087. # internally, all values are converted to str in MLflow
  1088. if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
  1089. logger.warning(
  1090. f'Trainer is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s'
  1091. " log_param() only accepts values no longer than 250 characters so we dropped this attribute."
  1092. " You can use `MLFLOW_FLATTEN_PARAMS` environment variable to flatten the parameters and"
  1093. " avoid this message."
  1094. )
  1095. del combined_dict[name]
  1096. # MLflow cannot log more than 100 values in one go, so we have to split it
  1097. combined_dict_items = list(combined_dict.items())
  1098. if self._max_log_params and self._max_log_params.isdigit():
  1099. max_log_params = int(self._max_log_params)
  1100. if max_log_params < len(combined_dict_items):
  1101. logger.debug(
  1102. f"Reducing the number of parameters to log from {len(combined_dict_items)} to {max_log_params}."
  1103. )
  1104. combined_dict_items = combined_dict_items[:max_log_params]
  1105. for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
  1106. if self._async_log:
  1107. self._ml_flow.log_params(
  1108. dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]), synchronous=False
  1109. )
  1110. else:
  1111. self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
  1112. mlflow_tags = os.getenv("MLFLOW_TAGS", None)
  1113. if mlflow_tags:
  1114. mlflow_tags = json.loads(mlflow_tags)
  1115. self._ml_flow.set_tags(mlflow_tags)
  1116. self._initialized = True
  1117. def on_train_begin(self, args, state, control, model=None, **kwargs):
  1118. if not self._initialized:
  1119. self.setup(args, state, model)
  1120. def on_log(self, args, state, control, logs, model=None, **kwargs):
  1121. if not self._initialized:
  1122. self.setup(args, state, model)
  1123. if state.is_world_process_zero:
  1124. metrics = {}
  1125. for k, v in logs.items():
  1126. if isinstance(v, (int, float)):
  1127. metrics[k] = v
  1128. elif isinstance(v, torch.Tensor) and v.numel() == 1:
  1129. metrics[k] = v.item()
  1130. else:
  1131. logger.warning(
  1132. f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
  1133. "MLflow's log_metric() only accepts float and int types so we dropped this attribute."
  1134. )
  1135. # sanitize metric names to replace unsupported characters like parentheses
  1136. sanitized_metrics = {re.sub(r"[^0-9A-Za-z_\-\.\ :/]", "_", k): v for k, v in metrics.items()}
  1137. if self._async_log:
  1138. self._ml_flow.log_metrics(metrics=sanitized_metrics, step=state.global_step, synchronous=False)
  1139. else:
  1140. self._ml_flow.log_metrics(metrics=sanitized_metrics, step=state.global_step)
  1141. def on_train_end(self, args, state, control, **kwargs):
  1142. if self._initialized and state.is_world_process_zero:
  1143. if self._auto_end_run and self._ml_flow.active_run():
  1144. self._ml_flow.end_run()
  1145. def on_save(self, args, state, control, **kwargs):
  1146. if self._initialized and state.is_world_process_zero and self._log_artifacts:
  1147. ckpt_dir = f"checkpoint-{state.global_step}"
  1148. artifact_path = os.path.join(args.output_dir, ckpt_dir)
  1149. logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.")
  1150. self._ml_flow.pyfunc.log_model(
  1151. ckpt_dir,
  1152. artifacts={"model_path": artifact_path},
  1153. python_model=self._ml_flow.pyfunc.PythonModel(),
  1154. )
  1155. def __del__(self):
  1156. # if the previous run is not terminated correctly, the fluent API will
  1157. # not let you start a new run before the previous one is killed
  1158. if (
  1159. self._auto_end_run
  1160. and callable(getattr(self._ml_flow, "active_run", None))
  1161. and self._ml_flow.active_run() is not None
  1162. ):
  1163. self._ml_flow.end_run()
  1164. class DagsHubCallback(MLflowCallback):
  1165. """
  1166. A [`TrainerCallback`] that logs to [DagsHub](https://dagshub.com/). Extends [`MLflowCallback`]
  1167. """
  1168. def __init__(self):
  1169. super().__init__()
  1170. if not is_dagshub_available():
  1171. raise ImportError("DagsHubCallback requires dagshub to be installed. Run `pip install dagshub`.")
  1172. from dagshub.upload import Repo
  1173. self.Repo = Repo
  1174. def setup(self, *args, **kwargs):
  1175. """
  1176. Setup the DagsHub's Logging integration.
  1177. Environment:
  1178. - **HF_DAGSHUB_LOG_ARTIFACTS** (`str`, *optional*):
  1179. Whether to save the data and model artifacts for the experiment. Default to `False`.
  1180. """
  1181. self.log_artifacts = os.getenv("HF_DAGSHUB_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
  1182. self.name = os.getenv("HF_DAGSHUB_MODEL_NAME") or "main"
  1183. self.remote = os.getenv("MLFLOW_TRACKING_URI")
  1184. self.repo = self.Repo(
  1185. owner=self.remote.split(os.sep)[-2],
  1186. name=self.remote.split(os.sep)[-1].split(".")[0],
  1187. branch=os.getenv("BRANCH") or "main",
  1188. )
  1189. self.path = Path("artifacts")
  1190. if self.remote is None:
  1191. raise RuntimeError(
  1192. "DagsHubCallback requires the `MLFLOW_TRACKING_URI` environment variable to be set. Did you run"
  1193. " `dagshub.init()`?"
  1194. )
  1195. super().setup(*args, **kwargs)
  1196. def on_train_end(self, args, state, control, **kwargs):
  1197. if self.log_artifacts:
  1198. if getattr(self, "train_dataloader", None):
  1199. torch.save(self.train_dataloader.dataset, os.path.join(args.output_dir, "dataset.pt"))
  1200. self.repo.directory(str(self.path)).add_dir(args.output_dir)
  1201. class NeptuneMissingConfiguration(Exception):
  1202. def __init__(self):
  1203. super().__init__(
  1204. """
  1205. ------ Unsupported ---- We were not able to create new runs. You provided a custom Neptune run to
  1206. `NeptuneCallback` with the `run` argument. For the integration to work fully, provide your `api_token` and
  1207. `project` by saving them as environment variables or passing them to the callback.
  1208. """
  1209. )
  1210. class NeptuneCallback(TrainerCallback):
  1211. """TrainerCallback that sends the logs to [Neptune](https://app.neptune.ai).
  1212. > [!WARNING]
  1213. > Neptune integration is deprecated and will be removed in a future version of Transformers. We recommend using
  1214. > other supported experiment tracking integrations.
  1215. Args:
  1216. api_token (`str`, *optional*): Neptune API token obtained upon registration.
  1217. You can leave this argument out if you have saved your token to the `NEPTUNE_API_TOKEN` environment
  1218. variable (strongly recommended). See full setup instructions in the
  1219. [docs](https://docs.neptune.ai/setup/installation).
  1220. project (`str`, *optional*): Name of an existing Neptune project, in the form "workspace-name/project-name".
  1221. You can find and copy the name in Neptune from the project settings -> Properties. If None (default), the
  1222. value of the `NEPTUNE_PROJECT` environment variable is used.
  1223. name (`str`, *optional*): Custom name for the run.
  1224. base_namespace (`str`, *optional*, defaults to "finetuning"): In the Neptune run, the root namespace
  1225. that will contain all of the metadata logged by the callback.
  1226. log_parameters (`bool`, *optional*, defaults to `True`):
  1227. If True, logs all Trainer arguments and model parameters provided by the Trainer.
  1228. log_checkpoints (`str`, *optional*): If "same", uploads checkpoints whenever they are saved by the Trainer.
  1229. If "last", uploads only the most recently saved checkpoint. If "best", uploads the best checkpoint (among
  1230. the ones saved by the Trainer). If `None`, does not upload checkpoints.
  1231. run (`Run`, *optional*): Pass a Neptune run object if you want to continue logging to an existing run.
  1232. Read more about resuming runs in the [docs](https://docs.neptune.ai/logging/to_existing_object).
  1233. **neptune_run_kwargs (*optional*):
  1234. Additional keyword arguments to be passed directly to the
  1235. [`neptune.init_run()`](https://docs.neptune.ai/api/neptune#init_run) function when a new run is created.
  1236. For instructions and examples, see the [Transformers integration
  1237. guide](https://docs.neptune.ai/integrations/transformers) in the Neptune documentation.
  1238. """
  1239. integration_version_key = "source_code/integrations/transformers"
  1240. model_parameters_key = "model_parameters"
  1241. trial_name_key = "trial"
  1242. trial_params_key = "trial_params"
  1243. trainer_parameters_key = "trainer_parameters"
  1244. flat_metrics = {"train/epoch"}
  1245. def __init__(
  1246. self,
  1247. *,
  1248. api_token: str | None = None,
  1249. project: str | None = None,
  1250. name: str | None = None,
  1251. base_namespace: str = "finetuning",
  1252. run=None,
  1253. log_parameters: bool = True,
  1254. log_checkpoints: str | None = None,
  1255. **neptune_run_kwargs,
  1256. ):
  1257. warnings.warn(
  1258. "The NeptuneCallback is deprecated and will be removed in a future version of Transformers. We recommend "
  1259. "using other supported experiment tracking integrations.",
  1260. FutureWarning,
  1261. )
  1262. if not is_neptune_available():
  1263. raise ValueError(
  1264. "NeptuneCallback requires the Neptune client library to be installed. "
  1265. "To install the library, run `pip install neptune`."
  1266. )
  1267. try:
  1268. from neptune import Run
  1269. from neptune.internal.utils import verify_type
  1270. except ImportError:
  1271. from neptune.new.internal.utils import verify_type
  1272. from neptune.new.metadata_containers.run import Run
  1273. verify_type("api_token", api_token, (str, type(None)))
  1274. verify_type("project", project, (str, type(None)))
  1275. verify_type("name", name, (str, type(None)))
  1276. verify_type("base_namespace", base_namespace, str)
  1277. verify_type("run", run, (Run, type(None)))
  1278. verify_type("log_parameters", log_parameters, bool)
  1279. verify_type("log_checkpoints", log_checkpoints, (str, type(None)))
  1280. self._base_namespace_path = base_namespace
  1281. self._log_parameters = log_parameters
  1282. self._log_checkpoints = log_checkpoints
  1283. self._initial_run: Run | None = run
  1284. self._run = None
  1285. self._is_monitoring_run = False
  1286. self._run_id = None
  1287. self._force_reset_monitoring_run = False
  1288. self._init_run_kwargs = {"api_token": api_token, "project": project, "name": name, **neptune_run_kwargs}
  1289. self._volatile_checkpoints_dir = None
  1290. self._should_upload_checkpoint = self._log_checkpoints is not None
  1291. self._recent_checkpoint_path = None
  1292. if self._log_checkpoints in {"last", "best"}:
  1293. self._target_checkpoints_namespace = f"checkpoints/{self._log_checkpoints}"
  1294. self._should_clean_recently_uploaded_checkpoint = True
  1295. else:
  1296. self._target_checkpoints_namespace = "checkpoints"
  1297. self._should_clean_recently_uploaded_checkpoint = False
  1298. def _stop_run_if_exists(self):
  1299. if self._run:
  1300. self._run.stop()
  1301. del self._run
  1302. self._run = None
  1303. def _initialize_run(self, **additional_neptune_kwargs):
  1304. try:
  1305. from neptune import init_run
  1306. from neptune.exceptions import NeptuneMissingApiTokenException, NeptuneMissingProjectNameException
  1307. except ImportError:
  1308. from neptune.new import init_run
  1309. from neptune.new.exceptions import NeptuneMissingApiTokenException, NeptuneMissingProjectNameException
  1310. self._stop_run_if_exists()
  1311. try:
  1312. run_params = additional_neptune_kwargs.copy()
  1313. run_params.update(self._init_run_kwargs)
  1314. self._run = init_run(**run_params)
  1315. self._run_id = self._run["sys/id"].fetch()
  1316. except (NeptuneMissingProjectNameException, NeptuneMissingApiTokenException) as e:
  1317. raise NeptuneMissingConfiguration() from e
  1318. def _use_initial_run(self):
  1319. self._run = self._initial_run
  1320. self._is_monitoring_run = True
  1321. self._run_id = self._run["sys/id"].fetch()
  1322. self._initial_run = None
  1323. def _ensure_run_with_monitoring(self):
  1324. if self._initial_run is not None:
  1325. self._use_initial_run()
  1326. else:
  1327. if not self._force_reset_monitoring_run and self._is_monitoring_run:
  1328. return
  1329. if self._run and not self._is_monitoring_run and not self._force_reset_monitoring_run:
  1330. self._initialize_run(with_id=self._run_id)
  1331. self._is_monitoring_run = True
  1332. else:
  1333. self._initialize_run()
  1334. self._force_reset_monitoring_run = False
  1335. def _ensure_at_least_run_without_monitoring(self):
  1336. if self._initial_run is not None:
  1337. self._use_initial_run()
  1338. else:
  1339. if not self._run:
  1340. self._initialize_run(
  1341. with_id=self._run_id,
  1342. capture_stdout=False,
  1343. capture_stderr=False,
  1344. capture_hardware_metrics=False,
  1345. capture_traceback=False,
  1346. )
  1347. self._is_monitoring_run = False
  1348. @property
  1349. def run(self):
  1350. if self._run is None:
  1351. self._ensure_at_least_run_without_monitoring()
  1352. return self._run
  1353. @property
  1354. def _metadata_namespace(self):
  1355. return self.run[self._base_namespace_path]
  1356. def _log_integration_version(self):
  1357. self.run[NeptuneCallback.integration_version_key] = version
  1358. def _log_trainer_parameters(self, args):
  1359. self._metadata_namespace[NeptuneCallback.trainer_parameters_key] = args.to_sanitized_dict()
  1360. def _log_model_parameters(self, model):
  1361. from neptune.utils import stringify_unsupported
  1362. if model and hasattr(model, "config") and model.config is not None:
  1363. self._metadata_namespace[NeptuneCallback.model_parameters_key] = stringify_unsupported(
  1364. model.config.to_dict()
  1365. )
  1366. def _log_hyper_param_search_parameters(self, state):
  1367. if state and hasattr(state, "trial_name"):
  1368. self._metadata_namespace[NeptuneCallback.trial_name_key] = state.trial_name
  1369. if state and hasattr(state, "trial_params") and state.trial_params is not None:
  1370. self._metadata_namespace[NeptuneCallback.trial_params_key] = state.trial_params
  1371. def _log_model_checkpoint(self, source_directory: str, checkpoint: str):
  1372. target_path = relative_path = os.path.join(source_directory, checkpoint)
  1373. if self._volatile_checkpoints_dir is not None:
  1374. consistent_checkpoint_path = os.path.join(self._volatile_checkpoints_dir, checkpoint)
  1375. try:
  1376. # Remove leading ../ from a relative path.
  1377. cpkt_path = relative_path.replace("..", "").lstrip(os.path.sep)
  1378. copy_path = os.path.join(consistent_checkpoint_path, cpkt_path)
  1379. shutil.copytree(relative_path, copy_path)
  1380. target_path = consistent_checkpoint_path
  1381. except OSError as e:
  1382. logger.warning(
  1383. f"NeptuneCallback was unable to made a copy of checkpoint due to I/O exception: '{e}'. "
  1384. "Could fail trying to upload."
  1385. )
  1386. self._metadata_namespace[self._target_checkpoints_namespace].upload_files(target_path)
  1387. if self._should_clean_recently_uploaded_checkpoint and self._recent_checkpoint_path is not None:
  1388. self._metadata_namespace[self._target_checkpoints_namespace].delete_files(self._recent_checkpoint_path)
  1389. self._recent_checkpoint_path = relative_path
  1390. def on_init_end(self, args, state, control, **kwargs):
  1391. self._volatile_checkpoints_dir = None
  1392. if self._log_checkpoints and args.save_total_limit is not None:
  1393. self._volatile_checkpoints_dir = tempfile.TemporaryDirectory().name
  1394. if self._log_checkpoints == "best" and not args.load_best_model_at_end:
  1395. raise ValueError("To save the best model checkpoint, the load_best_model_at_end argument must be enabled.")
  1396. def on_train_begin(self, args, state, control, model=None, **kwargs):
  1397. if not state.is_world_process_zero:
  1398. return
  1399. self._ensure_run_with_monitoring()
  1400. self._force_reset_monitoring_run = True
  1401. self._log_integration_version()
  1402. if self._log_parameters:
  1403. self._log_trainer_parameters(args)
  1404. self._log_model_parameters(model)
  1405. if state.is_hyper_param_search:
  1406. self._log_hyper_param_search_parameters(state)
  1407. def on_train_end(self, args, state, control, **kwargs):
  1408. self._stop_run_if_exists()
  1409. def __del__(self):
  1410. if self._volatile_checkpoints_dir is not None:
  1411. shutil.rmtree(self._volatile_checkpoints_dir, ignore_errors=True)
  1412. self._stop_run_if_exists()
  1413. def on_save(self, args, state, control, **kwargs):
  1414. if self._should_upload_checkpoint:
  1415. self._log_model_checkpoint(args.output_dir, f"checkpoint-{state.global_step}")
  1416. def on_evaluate(self, args, state, control, metrics=None, **kwargs):
  1417. if self._log_checkpoints == "best":
  1418. best_metric_name = args.metric_for_best_model
  1419. if not best_metric_name.startswith("eval_"):
  1420. best_metric_name = f"eval_{best_metric_name}"
  1421. metric_value = metrics.get(best_metric_name)
  1422. operator = np.greater if args.greater_is_better else np.less
  1423. self._should_upload_checkpoint = state.best_metric is None or operator(metric_value, state.best_metric)
  1424. @classmethod
  1425. def get_run(cls, trainer):
  1426. for callback in trainer.callback_handler.callbacks:
  1427. if isinstance(callback, cls):
  1428. return callback.run
  1429. raise Exception("The trainer doesn't have a NeptuneCallback configured.")
  1430. def on_log(self, args, state, control, logs: dict[str, float] | None = None, **kwargs):
  1431. if not state.is_world_process_zero:
  1432. return
  1433. if logs is not None:
  1434. for name, value in rewrite_logs(logs).items():
  1435. if isinstance(value, (int, float)):
  1436. if name in NeptuneCallback.flat_metrics:
  1437. self._metadata_namespace[name] = value
  1438. else:
  1439. self._metadata_namespace[name].log(value, step=state.global_step)
  1440. class CodeCarbonCallback(TrainerCallback):
  1441. """
  1442. A [`TrainerCallback`] that tracks the CO2 emission of training.
  1443. """
  1444. def __init__(self):
  1445. if not is_codecarbon_available():
  1446. raise RuntimeError(
  1447. "CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
  1448. )
  1449. elif torch.version.hip:
  1450. raise RuntimeError(
  1451. "CodeCarbonCallback requires `codecarbon` package, which is not compatible with AMD ROCm (https://github.com/mlco2/codecarbon/pull/490). When using the Trainer, please specify the `report_to` argument (https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments.report_to) to disable CodeCarbonCallback."
  1452. )
  1453. import codecarbon
  1454. self._codecarbon = codecarbon
  1455. self.tracker = None
  1456. def on_init_end(self, args, state, control, **kwargs):
  1457. if self.tracker is None and state.is_local_process_zero:
  1458. # CodeCarbon will automatically handle environment variables for configuration
  1459. self.tracker = self._codecarbon.EmissionsTracker(output_dir=args.output_dir)
  1460. def on_train_begin(self, args, state, control, model=None, **kwargs):
  1461. if self.tracker and state.is_local_process_zero:
  1462. self.tracker.start()
  1463. def on_train_end(self, args, state, control, **kwargs):
  1464. if self.tracker and state.is_local_process_zero:
  1465. self.tracker.stop()
  1466. class ClearMLCallback(TrainerCallback):
  1467. """
  1468. A [`TrainerCallback`] that sends the logs to [ClearML](https://clear.ml/).
  1469. Environment:
  1470. - **CLEARML_PROJECT** (`str`, *optional*, defaults to `HuggingFace Transformers`):
  1471. ClearML project name.
  1472. - **CLEARML_TASK** (`str`, *optional*, defaults to `Trainer`):
  1473. ClearML task name.
  1474. - **CLEARML_LOG_MODEL** (`bool`, *optional*, defaults to `False`):
  1475. Whether to log models as artifacts during training.
  1476. """
  1477. log_suffix = ""
  1478. _hparams_section = "Transformers"
  1479. _model_config_section = "Model Configuration"
  1480. _ignore_hparams_overrides = "_ignore_hparams_ui_overrides_"
  1481. _ignoge_model_config_overrides = "_ignore_model_config_ui_overrides_"
  1482. _model_config_description = "The configuration of model number {}."
  1483. _model_config_description_note = (
  1484. "Note that, when cloning this task and running it remotely,"
  1485. " the configuration might be applied to another model instead of this one."
  1486. " To avoid this, initialize the task externally by calling `Task.init`"
  1487. " before the `ClearMLCallback` is instantiated."
  1488. )
  1489. _train_run_counter = 0
  1490. _model_connect_counter = 0
  1491. _task_created_in_callback = False
  1492. _should_close_on_train_end = None
  1493. def __init__(self):
  1494. if is_clearml_available():
  1495. import clearml
  1496. self._clearml = clearml
  1497. else:
  1498. raise RuntimeError("ClearMLCallback requires 'clearml' to be installed. Run `pip install clearml`.")
  1499. self._initialized = False
  1500. self._clearml_task = None
  1501. self._log_model = False
  1502. self._checkpoints_saved = []
  1503. def setup(self, args, state, model, processing_class, **kwargs):
  1504. if self._clearml is None:
  1505. return
  1506. if self._initialized:
  1507. return
  1508. ClearMLCallback._train_run_counter += 1
  1509. ClearMLCallback._model_connect_counter += 1
  1510. ClearMLCallback.log_suffix = (
  1511. "" if ClearMLCallback._train_run_counter == 1 else "_" + str(ClearMLCallback._train_run_counter)
  1512. )
  1513. if state.is_world_process_zero:
  1514. logger.info("Automatic ClearML logging enabled.")
  1515. if self._clearml_task is None:
  1516. if ClearMLCallback._should_close_on_train_end is None:
  1517. if not self._clearml.Task.running_locally() or self._clearml.Task.current_task():
  1518. ClearMLCallback._should_close_on_train_end = False
  1519. else:
  1520. ClearMLCallback._should_close_on_train_end = True
  1521. # This might happen when running inside of a pipeline, where the task is already initialized
  1522. # from outside of Hugging Face
  1523. if self._clearml.Task.running_locally() and self._clearml.Task.current_task():
  1524. self._clearml_task = self._clearml.Task.current_task()
  1525. self._log_model = os.getenv(
  1526. "CLEARML_LOG_MODEL",
  1527. "FALSE" if not ClearMLCallback._task_created_in_callback else "TRUE",
  1528. ).upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
  1529. logger.info("External ClearML Task has been connected.")
  1530. else:
  1531. self._clearml_task = self._clearml.Task.init(
  1532. project_name=os.getenv("CLEARML_PROJECT", "HuggingFace Transformers"),
  1533. task_name=os.getenv("CLEARML_TASK", "Trainer"),
  1534. auto_connect_frameworks={"tensorboard": False, "pytorch": False},
  1535. output_uri=True,
  1536. )
  1537. self._log_model = os.getenv("CLEARML_LOG_MODEL", "TRUE").upper() in ENV_VARS_TRUE_VALUES.union(
  1538. {"TRUE"}
  1539. )
  1540. ClearMLCallback._task_created_in_callback = True
  1541. logger.info("ClearML Task has been initialized.")
  1542. self._initialized = True
  1543. suffixed_hparams_section = ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
  1544. ignore_hparams_config_section = suffixed_hparams_section + "/" + ClearMLCallback._ignore_hparams_overrides
  1545. if self._clearml.Task.running_locally():
  1546. self._copy_training_args_as_hparams(args, suffixed_hparams_section)
  1547. self._clearml_task.set_parameter(
  1548. name=ignore_hparams_config_section,
  1549. value=True,
  1550. value_type=bool,
  1551. description=(
  1552. "If True, ignore Transformers hyperparameters overrides done in the UI/backend "
  1553. + "when running remotely. Otherwise, the overrides will be applied when running remotely"
  1554. ),
  1555. )
  1556. elif not self._clearml_task.get_parameter(ignore_hparams_config_section, default=True, cast=True):
  1557. self._clearml_task.connect(args, suffixed_hparams_section)
  1558. else:
  1559. self._copy_training_args_as_hparams(
  1560. args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
  1561. )
  1562. if getattr(model, "config", None) is not None:
  1563. ignore_model_config_section = (
  1564. suffixed_hparams_section + "/" + ClearMLCallback._ignoge_model_config_overrides
  1565. )
  1566. configuration_object_description = ClearMLCallback._model_config_description.format(
  1567. ClearMLCallback._model_connect_counter
  1568. )
  1569. if ClearMLCallback._model_connect_counter != ClearMLCallback._train_run_counter:
  1570. configuration_object_description += " " + ClearMLCallback._model_config_description_note
  1571. if self._clearml.Task.running_locally():
  1572. self._clearml_task.set_parameter(
  1573. name=ignore_model_config_section,
  1574. value=True,
  1575. value_type=bool,
  1576. description=(
  1577. "If True, ignore Transformers model configuration overrides done in the UI/backend "
  1578. + "when running remotely. Otherwise, the overrides will be applied when running remotely"
  1579. ),
  1580. )
  1581. self._clearml_task.set_configuration_object(
  1582. name=ClearMLCallback._model_config_section + ClearMLCallback.log_suffix,
  1583. config_dict=model.config.to_dict(),
  1584. description=configuration_object_description,
  1585. )
  1586. elif not self._clearml_task.get_parameter(ignore_model_config_section, default=True, cast=True):
  1587. model.config = model.config.from_dict(
  1588. self._clearml_task.get_configuration_object_as_dict(
  1589. ClearMLCallback._model_config_section + ClearMLCallback.log_suffix
  1590. )
  1591. )
  1592. else:
  1593. self._clearml_task.set_configuration_object(
  1594. name=ClearMLCallback._model_config_section + ClearMLCallback.log_suffix,
  1595. config_dict=model.config.to_dict(),
  1596. description=configuration_object_description,
  1597. )
  1598. def on_train_begin(self, args, state, control, model=None, processing_class=None, **kwargs):
  1599. if self._clearml is None:
  1600. return
  1601. self._checkpoints_saved = []
  1602. if state.is_hyper_param_search:
  1603. self._initialized = False
  1604. if not self._initialized:
  1605. self.setup(args, state, model, processing_class, **kwargs)
  1606. def on_train_end(self, args, state, control, **kwargs):
  1607. if ClearMLCallback._should_close_on_train_end:
  1608. self._clearml_task.close()
  1609. ClearMLCallback._train_run_counter = 0
  1610. def on_log(self, args, state, control, model=None, processing_class=None, logs=None, **kwargs):
  1611. if self._clearml is None:
  1612. return
  1613. if not self._initialized:
  1614. self.setup(args, state, model, processing_class, **kwargs)
  1615. if state.is_world_process_zero:
  1616. eval_prefix = "eval_"
  1617. eval_prefix_len = len(eval_prefix)
  1618. test_prefix = "test_"
  1619. test_prefix_len = len(test_prefix)
  1620. single_value_scalars = [
  1621. "train_runtime",
  1622. "train_samples_per_second",
  1623. "train_steps_per_second",
  1624. "train_loss",
  1625. "total_flos",
  1626. "epoch",
  1627. ]
  1628. for k, v in logs.items():
  1629. if isinstance(v, (int, float)):
  1630. if k in single_value_scalars:
  1631. self._clearml_task.get_logger().report_single_value(
  1632. name=k + ClearMLCallback.log_suffix, value=v
  1633. )
  1634. elif k.startswith(eval_prefix):
  1635. self._clearml_task.get_logger().report_scalar(
  1636. title="eval" + ClearMLCallback.log_suffix,
  1637. series=k[eval_prefix_len:],
  1638. value=v,
  1639. iteration=state.global_step,
  1640. )
  1641. elif k.startswith(test_prefix):
  1642. self._clearml_task.get_logger().report_scalar(
  1643. title="test" + ClearMLCallback.log_suffix,
  1644. series=k[test_prefix_len:],
  1645. value=v,
  1646. iteration=state.global_step,
  1647. )
  1648. else:
  1649. self._clearml_task.get_logger().report_scalar(
  1650. title="train" + ClearMLCallback.log_suffix,
  1651. series=k,
  1652. value=v,
  1653. iteration=state.global_step,
  1654. )
  1655. else:
  1656. logger.warning(
  1657. "Trainer is attempting to log a value of "
  1658. f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
  1659. "This invocation of ClearML logger's report_scalar() "
  1660. "is incorrect so we dropped this attribute."
  1661. )
  1662. def on_save(self, args, state, control, **kwargs):
  1663. if self._log_model and self._clearml_task and state.is_world_process_zero:
  1664. ckpt_dir = f"checkpoint-{state.global_step}"
  1665. artifact_path = os.path.join(args.output_dir, ckpt_dir)
  1666. name = ckpt_dir + ClearMLCallback.log_suffix
  1667. logger.info(f"Logging checkpoint artifact `{name}`. This may take some time.")
  1668. output_model = self._clearml.OutputModel(task=self._clearml_task, name=name)
  1669. output_model.connect(task=self._clearml_task, name=name)
  1670. output_model.update_weights_package(
  1671. weights_path=artifact_path,
  1672. target_filename=ckpt_dir,
  1673. iteration=state.global_step,
  1674. auto_delete_file=False,
  1675. )
  1676. self._checkpoints_saved.append(output_model)
  1677. while args.save_total_limit and args.save_total_limit < len(self._checkpoints_saved):
  1678. try:
  1679. self._clearml.model.Model.remove(
  1680. self._checkpoints_saved[0],
  1681. delete_weights_file=True,
  1682. force=True,
  1683. raise_on_errors=True,
  1684. )
  1685. except Exception as e:
  1686. logger.warning(
  1687. f"Could not remove checkpoint `{self._checkpoints_saved[0].name}` after going over the `save_total_limit`. Error is: {e}"
  1688. )
  1689. break
  1690. self._checkpoints_saved = self._checkpoints_saved[1:]
  1691. def _copy_training_args_as_hparams(self, training_args, prefix):
  1692. as_dict = {
  1693. field.name: getattr(training_args, field.name)
  1694. for field in fields(training_args)
  1695. if field.init and not field.name.endswith("_token")
  1696. }
  1697. flat_dict = {str(k): v for k, v in self._clearml.utilities.proxy_object.flatten_dictionary(as_dict).items()}
  1698. self._clearml_task._arguments.copy_from_dict(flat_dict, prefix=prefix)
  1699. class FlyteCallback(TrainerCallback):
  1700. """A [`TrainerCallback`] that sends the logs to [Flyte](https://flyte.org/).
  1701. NOTE: This callback only works within a Flyte task.
  1702. Args:
  1703. save_log_history (`bool`, *optional*, defaults to `True`):
  1704. When set to True, the training logs are saved as a Flyte Deck.
  1705. sync_checkpoints (`bool`, *optional*, defaults to `True`):
  1706. When set to True, checkpoints are synced with Flyte and can be used to resume training in the case of an
  1707. interruption.
  1708. Example:
  1709. ```python
  1710. # Note: This example skips over some setup steps for brevity.
  1711. from flytekit import current_context, task
  1712. @task
  1713. def train_hf_transformer():
  1714. cp = current_context().checkpoint
  1715. trainer = Trainer(..., callbacks=[FlyteCallback()])
  1716. output = trainer.train(resume_from_checkpoint=cp.restore())
  1717. ```
  1718. """
  1719. def __init__(self, save_log_history: bool = True, sync_checkpoints: bool = True):
  1720. super().__init__()
  1721. if not is_flytekit_available():
  1722. raise ImportError("FlyteCallback requires flytekit to be installed. Run `pip install flytekit`.")
  1723. if not is_flyte_deck_standard_available() or not is_pandas_available():
  1724. logger.warning(
  1725. "Syncing log history requires both flytekitplugins-deck-standard and pandas to be installed. "
  1726. "Run `pip install flytekitplugins-deck-standard pandas` to enable this feature."
  1727. )
  1728. save_log_history = False
  1729. from flytekit import current_context
  1730. self.cp = current_context().checkpoint
  1731. self.save_log_history = save_log_history
  1732. self.sync_checkpoints = sync_checkpoints
  1733. def on_save(self, args, state, control, **kwargs):
  1734. if self.sync_checkpoints and state.is_world_process_zero:
  1735. ckpt_dir = f"checkpoint-{state.global_step}"
  1736. artifact_path = os.path.join(args.output_dir, ckpt_dir)
  1737. logger.info(f"Syncing checkpoint in {ckpt_dir} to Flyte. This may take time.")
  1738. self.cp.save(artifact_path)
  1739. def on_train_end(self, args, state, control, **kwargs):
  1740. if self.save_log_history:
  1741. import pandas as pd
  1742. from flytekit import Deck
  1743. from flytekitplugins.deck.renderer import TableRenderer
  1744. log_history_df = pd.DataFrame(state.log_history)
  1745. Deck("Log History", TableRenderer().to_html(log_history_df))
  1746. class DVCLiveCallback(TrainerCallback):
  1747. """
  1748. A [`TrainerCallback`] that sends the logs to [DVCLive](https://www.dvc.org/doc/dvclive).
  1749. Use the environment variables below in `setup` to configure the integration. To customize this callback beyond
  1750. those environment variables, see [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).
  1751. Args:
  1752. live (`dvclive.Live`, *optional*, defaults to `None`):
  1753. Optional Live instance. If None, a new instance will be created using **kwargs.
  1754. log_model (Union[Literal["all"], bool], *optional*, defaults to `None`):
  1755. Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True`,
  1756. the final checkpoint is logged at the end of training. If set to `"all"`, the entire
  1757. [`TrainingArguments`]'s `output_dir` is logged at each checkpoint.
  1758. """
  1759. def __init__(
  1760. self,
  1761. live: Any | None = None,
  1762. log_model: Literal["all"] | bool | None = None,
  1763. **kwargs,
  1764. ):
  1765. if not is_dvclive_available():
  1766. raise RuntimeError("DVCLiveCallback requires dvclive to be installed. Run `pip install dvclive`.")
  1767. from dvclive import Live
  1768. self._initialized = False
  1769. self.live = None
  1770. if isinstance(live, Live):
  1771. self.live = live
  1772. elif live is not None:
  1773. raise RuntimeError(f"Found class {live.__class__} for live, expected dvclive.Live")
  1774. self._log_model = log_model
  1775. if self._log_model is None:
  1776. log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL", "FALSE")
  1777. if log_model_env.upper() in ENV_VARS_TRUE_VALUES:
  1778. self._log_model = True
  1779. elif log_model_env.lower() == "all":
  1780. self._log_model = "all"
  1781. def setup(self, args, state, model):
  1782. """
  1783. Setup the optional DVCLive integration. To customize this callback beyond the environment variables below, see
  1784. [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).
  1785. Environment:
  1786. - **HF_DVCLIVE_LOG_MODEL** (`str`, *optional*):
  1787. Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True` or
  1788. *1*, the final checkpoint is logged at the end of training. If set to `all`, the entire
  1789. [`TrainingArguments`]'s `output_dir` is logged at each checkpoint.
  1790. """
  1791. from dvclive import Live
  1792. self._initialized = True
  1793. if state.is_world_process_zero:
  1794. if not self.live:
  1795. self.live = Live()
  1796. self.live.log_params(args.to_dict())
  1797. def on_train_begin(self, args, state, control, model=None, **kwargs):
  1798. if not self._initialized:
  1799. self.setup(args, state, model)
  1800. def on_log(self, args, state, control, model=None, logs=None, **kwargs):
  1801. if not self._initialized:
  1802. self.setup(args, state, model)
  1803. if state.is_world_process_zero:
  1804. from dvclive.plots import Metric
  1805. from dvclive.utils import standardize_metric_name
  1806. for key, value in logs.items():
  1807. if Metric.could_log(value):
  1808. self.live.log_metric(standardize_metric_name(key, "dvclive.huggingface"), value)
  1809. else:
  1810. logger.warning(
  1811. "Trainer is attempting to log a value of "
  1812. f'"{value}" of type {type(value)} for key "{key}" as a scalar. '
  1813. "This invocation of DVCLive's Live.log_metric() "
  1814. "is incorrect so we dropped this attribute."
  1815. )
  1816. self.live.next_step()
  1817. def on_save(self, args, state, control, **kwargs):
  1818. if self._log_model == "all" and self._initialized and state.is_world_process_zero:
  1819. self.live.log_artifact(args.output_dir)
  1820. def on_train_end(self, args, state, control, **kwargs):
  1821. if self._initialized and state.is_world_process_zero:
  1822. from transformers.trainer import Trainer
  1823. if self._log_model is True:
  1824. fake_trainer = Trainer(
  1825. args=args,
  1826. model=kwargs.get("model"),
  1827. processing_class=kwargs.get("processing_class"),
  1828. eval_dataset=["fake"],
  1829. )
  1830. name = "best" if args.load_best_model_at_end else "last"
  1831. output_dir = os.path.join(args.output_dir, name)
  1832. fake_trainer.save_model(output_dir)
  1833. self.live.log_artifact(output_dir, name=name, type="model", copy=True)
  1834. self.live.end()
  1835. class SwanLabCallback(TrainerCallback):
  1836. """
  1837. A [`TrainerCallback`] that logs metrics, media, model checkpoints to [SwanLab](https://swanlab.cn/).
  1838. """
  1839. def __init__(self):
  1840. if not is_swanlab_available():
  1841. raise RuntimeError("SwanLabCallback requires swanlab to be installed. Run `pip install swanlab`.")
  1842. import swanlab
  1843. self._swanlab = swanlab
  1844. self._initialized = False
  1845. self._log_model = os.getenv("SWANLAB_LOG_MODEL", None)
  1846. def setup(self, args, state, model, **kwargs):
  1847. """
  1848. Setup the optional SwanLab (*swanlab*) integration.
  1849. One can subclass and override this method to customize the setup if needed. Find more information
  1850. [here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).
  1851. You can also override the following environment variables. Find more information about environment
  1852. variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)
  1853. Environment:
  1854. - **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
  1855. Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
  1856. checks if the user is already logged in. If not, the login process is initiated.
  1857. - If a string is passed to the login interface, this environment variable is ignored.
  1858. - If the user is already logged in, this environment variable takes precedence over locally stored
  1859. login information.
  1860. - **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
  1861. Set this to a custom string to store results in a different project. If not specified, the name of the current
  1862. running directory is used.
  1863. - **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
  1864. This environment variable specifies the storage path for log files when running in local mode.
  1865. By default, logs are saved in a folder named swanlog under the working directory.
  1866. - **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
  1867. SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
  1868. local, cloud, and disabled. Note: Case-sensitive. Find more information
  1869. [here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)
  1870. - **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
  1871. SwanLab does not currently support the save mode functionality.This feature will be available in a future
  1872. release
  1873. - **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
  1874. Web address for the SwanLab cloud environment for private version (its free)
  1875. - **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
  1876. API address for the SwanLab cloud environment for private version (its free)
  1877. - **SWANLAB_RUN_ID** (`str`, *optional*, defaults to `None`):
  1878. Experiment ID to resume a previous run. Use with `SWANLAB_RESUME` to continue an existing experiment.
  1879. - **SWANLAB_RESUME** (`str`, *optional*, defaults to `None`):
  1880. Resume mode (`"must"`, `"allow"`, `"never"`). Defaults to `"allow"` when `resume_from_checkpoint` is used.
  1881. """
  1882. self._initialized = True
  1883. if state.is_world_process_zero:
  1884. logger.info('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
  1885. combined_dict = {**args.to_dict()}
  1886. if hasattr(model, "config") and model.config is not None:
  1887. model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
  1888. combined_dict = {**model_config, **combined_dict}
  1889. if hasattr(model, "peft_config") and model.peft_config is not None:
  1890. peft_config = model.peft_config
  1891. combined_dict = {"peft_config": peft_config, **combined_dict}
  1892. trial_name = state.trial_name
  1893. init_args = {}
  1894. if trial_name is not None and args.run_name is not None:
  1895. init_args["experiment_name"] = f"{args.run_name}-{trial_name}"
  1896. elif args.run_name is not None:
  1897. init_args["experiment_name"] = args.run_name
  1898. elif trial_name is not None:
  1899. init_args["experiment_name"] = trial_name
  1900. init_args["project"] = os.getenv("SWANLAB_PROJECT", None)
  1901. run_id = os.getenv("SWANLAB_RUN_ID", None)
  1902. if run_id is not None:
  1903. init_args["id"] = run_id
  1904. resume = os.getenv("SWANLAB_RESUME", None)
  1905. if resume is not None:
  1906. init_args["resume"] = resume
  1907. elif args.resume_from_checkpoint:
  1908. init_args["resume"] = "allow"
  1909. if self._swanlab.get_run() is None:
  1910. self._swanlab.init(
  1911. **init_args,
  1912. )
  1913. # show transformers logo!
  1914. self._swanlab.config["FRAMEWORK"] = "🤗transformers"
  1915. # add config parameters (run may have been created manually)
  1916. self._swanlab.config.update(combined_dict)
  1917. # add number of model parameters to swanlab config
  1918. try:
  1919. self._swanlab.config.update({"model_num_parameters": model.num_parameters()})
  1920. # get peft model parameters
  1921. if type(model).__name__ == "PeftModel" or type(model).__name__ == "PeftMixedModel":
  1922. trainable_params, all_param = model.get_nb_trainable_parameters()
  1923. self._swanlab.config.update({"peft_model_trainable_params": trainable_params})
  1924. self._swanlab.config.update({"peft_model_all_param": all_param})
  1925. except AttributeError:
  1926. logger.info("Could not log the number of model parameters in SwanLab due to an AttributeError.")
  1927. # log the initial model architecture to an artifact
  1928. if self._log_model is not None:
  1929. logger.warning(
  1930. "SwanLab does not currently support the save mode functionality. "
  1931. "This feature will be available in a future release."
  1932. )
  1933. badge_markdown = (
  1934. f'[<img src="https://raw.githubusercontent.com/SwanHubX/assets/main/badge1.svg"'
  1935. f' alt="Visualize in SwanLab" height="28'
  1936. f'0" height="32"/>]({self._swanlab.get_run().public.cloud.experiment_url})'
  1937. )
  1938. modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
  1939. def on_train_begin(self, args, state, control, model=None, **kwargs):
  1940. if not self._initialized:
  1941. self.setup(args, state, model, **kwargs)
  1942. def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
  1943. if self._log_model is not None and self._initialized and state.is_world_process_zero:
  1944. logger.warning(
  1945. "SwanLab does not currently support the save mode functionality. "
  1946. "This feature will be available in a future release."
  1947. )
  1948. def on_log(self, args, state, control, model=None, logs=None, **kwargs):
  1949. single_value_scalars = [
  1950. "train_runtime",
  1951. "train_samples_per_second",
  1952. "train_steps_per_second",
  1953. "train_loss",
  1954. "total_flos",
  1955. ]
  1956. if not self._initialized:
  1957. self.setup(args, state, model)
  1958. if state.is_world_process_zero:
  1959. for k, v in logs.items():
  1960. if k in single_value_scalars:
  1961. self._swanlab.log({f"single_value/{k}": v}, step=state.global_step)
  1962. non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
  1963. non_scalar_logs = rewrite_logs(non_scalar_logs)
  1964. self._swanlab.log({**non_scalar_logs, "train/global_step": state.global_step}, step=state.global_step)
  1965. def on_save(self, args, state, control, **kwargs):
  1966. if self._log_model is not None and self._initialized and state.is_world_process_zero:
  1967. logger.warning(
  1968. "SwanLab does not currently support the save mode functionality. "
  1969. "This feature will be available in a future release."
  1970. )
  1971. def on_predict(self, args, state, control, metrics, **kwargs):
  1972. if not self._initialized:
  1973. self.setup(args, state, **kwargs)
  1974. if state.is_world_process_zero:
  1975. metrics = rewrite_logs(metrics)
  1976. self._swanlab.log(metrics)
  1977. class KubeflowCallback(TrainerCallback):
  1978. """
  1979. A [`TrainerCallback`] that reports training progress to [Kubeflow Trainer](https://github.com/kubeflow/trainer).
  1980. This callback is automatically registered when training inside a Kubeflow TrainJob with the
  1981. `TrainJobRuntimeStatus` feature gate enabled. The Kubeflow controller injects the required
  1982. environment variables into the training pod.
  1983. **Environment Variables (injected by controller):**
  1984. - `KUBEFLOW_TRAINER_SERVER_URL`: HTTPS endpoint for status updates
  1985. - `KUBEFLOW_TRAINER_SERVER_CA_CERT`: Path to CA certificate for TLS verification
  1986. - `KUBEFLOW_TRAINER_SERVER_TOKEN`: Path to service account token for authentication
  1987. **Reported Information:**
  1988. - Progress percentage (0-100%)
  1989. - Estimated time remaining (seconds)
  1990. - Training metrics (loss, learning_rate, etc.)
  1991. **Features:**
  1992. - Automatic throttling (max 1 update per 5 seconds) to avoid overwhelming the controller
  1993. - Token caching (5 minutes) to minimize file I/O
  1994. - Only rank 0 reports progress in distributed training
  1995. - Silent failures - network issues won't interrupt training
  1996. Can be disabled by setting environment variable `DISABLE_KUBEFLOW_INTEGRATION=TRUE`.
  1997. """
  1998. _MIN_UPDATE_INTERVAL = 5.0
  1999. _TOKEN_CACHE_DURATION = 300.0 # 5 minutes, aligned with SDK
  2000. _ENV_SERVER_URL = "KUBEFLOW_TRAINER_SERVER_URL"
  2001. _ENV_CA_CERT = "KUBEFLOW_TRAINER_SERVER_CA_CERT"
  2002. _ENV_TOKEN_PATH = "KUBEFLOW_TRAINER_SERVER_TOKEN"
  2003. def __init__(self):
  2004. if not is_kubeflow_available():
  2005. raise RuntimeError(
  2006. "KubeflowCallback requires KUBEFLOW_TRAINER_SERVER_URL environment variable to be set. "
  2007. "This is automatically set when running inside a Kubeflow TrainJob with TrainJobRuntimeStatus enabled."
  2008. )
  2009. self._initialized = False
  2010. self._metrics = {}
  2011. self._start_time = None
  2012. self._last_update_time = 0.0
  2013. self._cached_token = None
  2014. self._token_read_time = 0.0
  2015. self._ssl_context = None
  2016. self._ssl_context_initialized = False
  2017. logger.debug("[Kubeflow] Callback initialized")
  2018. def _get_ssl_context(self):
  2019. """Get cached SSL context for TLS verification."""
  2020. import ssl
  2021. if self._ssl_context_initialized:
  2022. return self._ssl_context
  2023. ca_file = os.environ.get(self._ENV_CA_CERT)
  2024. if ca_file:
  2025. try:
  2026. self._ssl_context = ssl.create_default_context(cafile=ca_file)
  2027. except Exception as e:
  2028. logger.warning(f"[Kubeflow] Failed to create SSL context with CA file {ca_file}: {e}")
  2029. self._ssl_context = None
  2030. self._ssl_context_initialized = True
  2031. return self._ssl_context
  2032. def _get_token(self):
  2033. """Get cached service account token."""
  2034. import time
  2035. now = time.monotonic()
  2036. if self._cached_token and (now - self._token_read_time) < self._TOKEN_CACHE_DURATION:
  2037. return self._cached_token
  2038. token_path = os.environ.get(self._ENV_TOKEN_PATH)
  2039. if not token_path or not os.path.exists(token_path):
  2040. logger.debug(f"[Kubeflow] Token file not found: {token_path}")
  2041. return None
  2042. try:
  2043. with open(token_path) as f:
  2044. self._cached_token = f.read().strip()
  2045. self._token_read_time = now
  2046. return self._cached_token
  2047. except OSError as e:
  2048. logger.debug(f"[Kubeflow] Failed to read token file: {e}")
  2049. return None
  2050. def _update_status(self, progress_percent=None, estimated_time_remaining=None, metrics=None, force=False):
  2051. """Send progress update to Kubeflow Trainer controller."""
  2052. import json
  2053. import time
  2054. import urllib.request
  2055. from datetime import datetime, timezone
  2056. try:
  2057. url = os.environ.get(self._ENV_SERVER_URL)
  2058. if not url:
  2059. return False
  2060. now = time.monotonic()
  2061. if not force and (now - self._last_update_time) < self._MIN_UPDATE_INTERVAL:
  2062. return False
  2063. self._last_update_time = now
  2064. token = self._get_token()
  2065. if not token:
  2066. return False
  2067. trainer_status = {"lastUpdatedTime": datetime.now(timezone.utc).isoformat()}
  2068. if progress_percent is not None:
  2069. trainer_status["progressPercentage"] = max(0, min(100, progress_percent))
  2070. if estimated_time_remaining is not None:
  2071. trainer_status["estimatedRemainingSeconds"] = max(0, int(estimated_time_remaining))
  2072. if metrics:
  2073. trainer_status["metrics"] = [{"name": str(k), "value": str(v)} for k, v in metrics.items()]
  2074. data = json.dumps({"trainerStatus": trainer_status}).encode("utf-8")
  2075. headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
  2076. req = urllib.request.Request(url, data=data, headers=headers, method="POST")
  2077. with urllib.request.urlopen(req, timeout=5, context=self._get_ssl_context()) as resp:
  2078. return resp.status == 200
  2079. except Exception as e:
  2080. logger.debug(f"[Kubeflow] Failed to update status: {e}")
  2081. return False
  2082. def on_train_begin(self, args, state, control, **kwargs):
  2083. if not state.is_world_process_zero:
  2084. return
  2085. import time
  2086. self._start_time = time.time()
  2087. self._metrics = {}
  2088. self._initialized = True
  2089. logger.debug(f"[Kubeflow] Training started, max_steps={state.max_steps}")
  2090. self._update_status(
  2091. progress_percent=0,
  2092. metrics={"total_steps": state.max_steps} if state.max_steps else None,
  2093. force=True,
  2094. )
  2095. def on_log(self, args, state, control, logs=None, **kwargs):
  2096. if not self._initialized or not state.is_world_process_zero or logs is None:
  2097. return
  2098. for key, value in logs.items():
  2099. if isinstance(value, (int, float)):
  2100. self._metrics[key] = value
  2101. def on_step_end(self, args, state, control, **kwargs):
  2102. if not self._initialized or not state.is_world_process_zero:
  2103. return
  2104. if not state.max_steps or state.max_steps <= 0:
  2105. return
  2106. import time
  2107. progress = int((state.global_step / state.max_steps) * 100)
  2108. # Cap at 99% until on_train_end reports 100% to indicate completion
  2109. progress = min(progress, 99)
  2110. eta_seconds = None
  2111. if self._start_time and state.global_step > 0:
  2112. elapsed = time.time() - self._start_time
  2113. avg_time_per_step = elapsed / state.global_step
  2114. remaining_steps = state.max_steps - state.global_step
  2115. eta_seconds = int(avg_time_per_step * remaining_steps)
  2116. metrics = {
  2117. **self._metrics,
  2118. "current_step": state.global_step,
  2119. "total_steps": state.max_steps,
  2120. }
  2121. if state.epoch is not None:
  2122. metrics["current_epoch"] = round(state.epoch, 2)
  2123. self._update_status(
  2124. progress_percent=progress,
  2125. estimated_time_remaining=eta_seconds,
  2126. metrics=metrics,
  2127. )
  2128. def on_train_end(self, args, state, control, **kwargs):
  2129. if not self._initialized or not state.is_world_process_zero:
  2130. return
  2131. logger.debug("[Kubeflow] Training completed")
  2132. self._update_status(
  2133. progress_percent=100,
  2134. estimated_time_remaining=0,
  2135. metrics=self._metrics,
  2136. force=True,
  2137. )
# Registry of supported reporting integrations: maps each integration name accepted in `report_to`
# to the callback class implementing it. `get_reporting_integration_callbacks` validates names against
# this mapping and looks the classes up here.
INTEGRATION_TO_CALLBACK = {
    "azure_ml": AzureMLCallback,
    "comet_ml": CometCallback,
    "mlflow": MLflowCallback,
    "neptune": NeptuneCallback,
    "tensorboard": TensorBoardCallback,
    "trackio": TrackioCallback,
    "wandb": WandbCallback,
    "codecarbon": CodeCarbonCallback,
    "clearml": ClearMLCallback,
    "dagshub": DagsHubCallback,
    "flyte": FlyteCallback,
    "dvclive": DVCLiveCallback,
    "swanlab": SwanLabCallback,
    "kubeflow": KubeflowCallback,
}
  2154. def get_reporting_integration_callbacks(report_to):
  2155. if report_to is None:
  2156. return []
  2157. if isinstance(report_to, str):
  2158. if report_to == "none":
  2159. return []
  2160. elif report_to == "all":
  2161. report_to = get_available_reporting_integrations()
  2162. else:
  2163. report_to = [report_to]
  2164. for integration in report_to:
  2165. if integration not in INTEGRATION_TO_CALLBACK:
  2166. raise ValueError(
  2167. f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
  2168. )
  2169. return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]