| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
277427842794280428142824283428442854286428742884289429042914292429342944295429642974298429943004301430243034304430543064307430843094310431143124313431443154316431743184319432043214322432343244325432643274328432943304331433243334334433543364337433843394340434143424343434443454346434743484349435043514352435343544355435643574358435943604361436243634364436543664367436843694370437143724373437443754376437743784379438043814382438343844385438643874388438943904391439243934394439543964397439843994400440144024403440444054406440744084409441044114412 |
- # Copyright 2020-present the HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- The Trainer class, to easily train a 🤗 Transformers model from scratch or fine-tune it on a new task.
- """
- import contextlib
- import functools
- import glob
- import inspect
- import json
- import math
- import os
- import random
- import shutil
- import sys
- import tempfile
- import time
- import warnings
- from collections.abc import Callable, Iterator, Mapping
- from functools import partial
- from pathlib import Path
- from typing import TYPE_CHECKING, Any
- # Integrations must be imported before ML frameworks:
- # ruff: isort: off
- from .integrations import (
- get_reporting_integration_callbacks,
- )
- # ruff: isort: on
- import numpy as np
- import safetensors.torch
- import torch
- import torch.distributed as dist
- from huggingface_hub import CommitInfo, ModelCard, create_repo, upload_folder
- from packaging import version
- from torch import nn
- from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
- from . import __version__
- from .configuration_utils import PreTrainedConfig
- from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
- from .debug_utils import DebugOption, DebugUnderflowOverflow
- from .feature_extraction_sequence_utils import SequenceFeatureExtractor
- from .feature_extraction_utils import FeatureExtractionMixin
- from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
- from .image_processing_utils import BaseImageProcessor
- from .integrations.deepspeed import (
- deepspeed_init,
- deepspeed_load_checkpoint,
- deepspeed_sp_compute_loss,
- is_deepspeed_available,
- propagate_args_to_deepspeed,
- )
- from .integrations.fsdp import get_fsdp_ckpt_kwargs, update_fsdp_plugin_peft
- from .integrations.liger import apply_liger_kernel
- from .integrations.neftune import activate_neftune, deactivate_neftune
- from .integrations.peft import MIN_PEFT_VERSION
- from .integrations.tpu import save_tpu_checkpoint, tpu_spmd_dataloader, wrap_model_xla_fsdp
- from .modelcard import TrainingSummary
- from .modeling_utils import PreTrainedModel, unwrap_model
- from .models.auto.modeling_auto import (
- MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
- MODEL_MAPPING_NAMES,
- )
- from .optimization import GreedyLR, get_scheduler
- from .processing_utils import ProcessorMixin
- from .tokenization_utils_base import PreTrainedTokenizerBase
- from .trainer_callback import (
- CallbackHandler,
- DefaultFlowCallback,
- ExportableState,
- PrinterCallback,
- ProgressCallback,
- TrainerCallback,
- TrainerControl,
- TrainerState,
- )
- from .trainer_optimizer import (
- _OPTIMIZER_HANDLERS,
- OptimizerContext,
- _parse_optim_args,
- is_optimizer_factory,
- )
- from .trainer_pt_utils import (
- EvalLoopContainer,
- IterableDatasetShard,
- LabelSmoother,
- LengthGroupedSampler,
- distributed_broadcast_scalars,
- find_batch_size,
- get_model_param_count,
- get_parameter_names,
- is_attention_mask_causal,
- nested_detach,
- nested_gather,
- reissue_pt_warnings,
- remove_dummy_checkpoint,
- safe_globals,
- set_rng_state_for_device,
- )
- from .trainer_utils import (
- PREFIX_CHECKPOINT_DIR,
- BestRun,
- EvalLoopOutput,
- EvalPrediction,
- HPSearchBackend,
- HubStrategy,
- PredictionOutput,
- RemoveColumnsCollator,
- SaveStrategy,
- TrainerMemoryTracker,
- TrainOutput,
- _is_peft_model,
- align_special_tokens,
- compare_trainer_and_checkpoint_args,
- default_compute_objective,
- denumpify_detensorize,
- enable_full_determinism,
- find_executable_batch_size,
- get_last_checkpoint,
- has_length,
- load_sharded_checkpoint,
- number_of_arguments,
- rotate_checkpoints,
- seed_worker,
- set_seed,
- sort_checkpoints,
- speed_metrics,
- suppress_progress_bars,
- unwrap_peft_model,
- validate_quantization_for_training,
- )
- from .training_args import OptimizerNames, ParallelMode, TrainingArguments
- from .utils import (
- ADAPTER_CONFIG_NAME,
- ADAPTER_SAFE_WEIGHTS_NAME,
- ADAPTER_WEIGHTS_NAME,
- CONFIG_NAME,
- GENERATION_CONFIG_NAME,
- SAFE_WEIGHTS_INDEX_NAME,
- SAFE_WEIGHTS_NAME,
- WEIGHTS_INDEX_NAME,
- WEIGHTS_NAME,
- XLA_FSDPV2_MIN_VERSION,
- PushInProgress,
- can_return_loss,
- check_torch_load_is_safe,
- find_labels,
- is_accelerate_available,
- is_datasets_available,
- is_in_notebook,
- is_peft_available,
- is_sagemaker_dp_enabled,
- is_sagemaker_mp_enabled,
- is_torch_hpu_available,
- is_torch_mlu_available,
- is_torch_musa_available,
- is_torch_npu_available,
- is_torch_xla_available,
- logging,
- )
- from .utils.import_utils import requires
- from .utils.quantization_config import QuantizationMethod
- DEFAULT_CALLBACKS = [DefaultFlowCallback]  # module-level list holding the DefaultFlowCallback class itself (not an instance)
- DEFAULT_PROGRESS_CALLBACK = ProgressCallback  # progress callback class used unless the notebook check below replaces it
- if is_in_notebook():  # when is_in_notebook() is true, switch to the notebook-specific progress callback
- from .utils.notebook import NotebookProgressCallback  # imported lazily, only on the notebook path
- DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback  # overrides the ProgressCallback default assigned above
- if is_datasets_available():  # optional dependency: `datasets` is imported only when the availability check passes
- import datasets
- if is_torch_xla_available():  # optional dependency: torch_xla modules plus a version probe
- import torch_xla.core.xla_model as xm
- import torch_xla.debug.metrics as met
- import torch_xla.runtime as xr
- from torch_xla import __version__ as XLA_VERSION
- IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)  # True when the installed torch_xla version is at least XLA_FSDPV2_MIN_VERSION
- if IS_XLA_FSDPV2_POST_2_2:  # the SPMD module is imported only when the version check above succeeded
- import torch_xla.distributed.spmd as xs
- else:  # NOTE(review): indentation is lost in this view; presumably pairs with the is_torch_xla_available() check so the flag is always defined -- confirm
- IS_XLA_FSDPV2_POST_2_2 = False
- if is_sagemaker_mp_enabled():  # optional dependency: SageMaker model-parallel library and its helper functions
- import smdistributed.modelparallel.torch as smp
- from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_nested_concat
- if is_peft_available():  # optional dependency: PeftModel is only bound when peft is available
- from peft import PeftModel
- if is_accelerate_available():  # accelerate names used by the Trainer; bound only when accelerate is available
- from accelerate import Accelerator, skip_first_batches
- from accelerate.state import AcceleratorState
- from accelerate.utils import (
- DataLoaderConfiguration,
- DistributedDataParallelKwargs,
- DistributedType,
- GradientAccumulationPlugin,
- load_fsdp_model,
- load_fsdp_optimizer,
- release_memory,
- save_fsdp_model,
- save_fsdp_optimizer,
- )
- from accelerate.utils.memory import clear_device_cache
- if is_deepspeed_available():  # NOTE(review): whether this check is nested under the accelerate check cannot be determined from this view -- confirm
- from accelerate.utils import DeepSpeedSchedulerWrapper
- if TYPE_CHECKING:  # typing.TYPE_CHECKING is False at runtime, so this import is seen only by static type checkers
- import optuna
- logger = logging.get_logger(__name__)  # module-level logger, obtained from the `logging` object imported from `.utils`
- # Names of the files used for checkpointing (how each one is read or written is not visible in this chunk)
- TRAINING_ARGS_NAME = "training_args.bin"
- TRAINER_STATE_NAME = "trainer_state.json"
- OPTIMIZER_NAME = "optimizer.pt"
- SCALER_NAME = "scaler.pt"
- OPTIMIZER_NAME_BIN = "optimizer.bin"  # second optimizer file name with a `.bin` suffix; NOTE(review): which code path uses it is not shown here -- confirm
- SCHEDULER_NAME = "scheduler.pt"
- FSDP_MODEL_NAME = "pytorch_model_fsdp"  # note: no file extension, unlike the names above
- @requires(
- backends=(
- "torch",
- "accelerate",
- )
- )
- class Trainer:
- """
- Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
- Args:
- model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
- The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
- <Tip>
- [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
- your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
- models.
- </Tip>
- args ([`TrainingArguments`], *optional*):
- The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
- `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
- data_collator (`DataCollator`, *optional*):
- The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
- default to [`default_data_collator`] if no `processing_class` is provided, an instance of
- [`DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or tokenizer.
- train_dataset (`torch.utils.data.Dataset` | `torch.utils.data.IterableDataset` | `datasets.Dataset`, *optional*):
- The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
- `model.forward()` method are automatically removed.
- Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
- distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
- `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
- manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
- sets the seed of the RNGs used.
- eval_dataset (`torch.utils.data.Dataset` | dict[str, `torch.utils.data.Dataset`] | `datasets.Dataset`, *optional*):
- The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
- `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
- dataset prepending the dictionary key to the metric name.
- processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
- Processing class used to process the data. If provided, will be used to automatically process the inputs
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
- reuse the fine-tuned model.
- model_init (`Callable[[], PreTrainedModel]`, *optional*):
- A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
- from a new instance of the model as given by this function.
- The function may have zero arguments, or a single one containing the optuna/Ray Tune trial object, to
- be able to choose different architectures according to hyperparameters (such as layer count, sizes of
- inner layers, dropout probabilities etc).
- compute_loss_func (`Callable`, *optional*):
- A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
- batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) used by [`Trainer`].
- compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
- The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and return
- a dictionary mapping strings to metric values. *Note*: when passing [`TrainingArguments`] with `batch_eval_metrics` set to
- `True`, your `compute_metrics` function must take a boolean `compute_result` argument. This will be triggered
- after the last eval batch to signal that the function needs to calculate and return the global summary
- statistics rather than accumulating the batch-level statistics.
- callbacks (List of [`TrainerCallback`], *optional*):
- A list of callbacks to customize the training loop. Will add those to the list of default callbacks
- detailed [here](callback).
- If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
- A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
- model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
- optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], dict[str, Any]]`, *optional*):
- A tuple containing the optimizer class and keyword arguments to use.
- Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
- Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
- A function that preprocesses the logits right before caching them at each evaluation step. Must take two
- tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
- by this function will be reflected in the predictions received by `compute_metrics`.
- Note that the labels (second parameter) will be `None` if the dataset does not have them.
- Important attributes:
- - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
- subclass.
- - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
- original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
- the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
- model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
- - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
- data parallelism, this means some of the model layers are split on different GPUs).
- - **place_model_on_device** -- Whether or not to automatically place the model on the device. Defaults to
- `True` unless model parallel, DeepSpeed, FSDP, full fp16/bf16 eval, or SageMaker MP is active. Can be
- overridden by subclassing `TrainingArguments` and overriding the `place_model_on_device` property.
- - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
- in `train`)
- """
- # Those methods are not used in Trainer itself but are available as methods for external use.
- from .trainer_pt_utils import (
- get_learning_rates,
- get_num_trainable_parameters,
- get_optimizer_group,
- log_metrics,
- metrics_format,
- save_metrics,
- save_state,
- )
- # ---- Initialization & Validation ----
def __init__(
    self,
    model: PreTrainedModel | nn.Module | None = None,
    args: TrainingArguments | None = None,
    data_collator: DataCollator | None = None,
    train_dataset: "Dataset | IterableDataset | datasets.Dataset | None" = None,
    eval_dataset: "Dataset | dict[str, Dataset] | datasets.Dataset | None" = None,
    processing_class: PreTrainedTokenizerBase
    | BaseImageProcessor
    | FeatureExtractionMixin
    | ProcessorMixin
    | None = None,
    model_init: Callable[..., PreTrainedModel] | None = None,
    compute_loss_func: Callable | None = None,
    compute_metrics: Callable[[EvalPrediction], dict] | None = None,
    callbacks: list[TrainerCallback] | None = None,
    optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
    optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
    preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
):
    """
    Build the trainer: resolve arguments and model, create the accelerator, decide device placement, store the
    user-supplied datasets/callables/optimizers, validate them, register callbacks and initialize training state.

    Args:
        model: The model to train. Mutually exclusive with `model_init`.
        args: The training arguments. When `None`, `TrainingArguments(output_dir="tmp_trainer")` is created.
        data_collator: Batch collator. When `None`, `DataCollatorWithPadding(processing_class)` is used if
            `processing_class` is a `PreTrainedTokenizerBase` or `SequenceFeatureExtractor`, otherwise
            `default_data_collator`.
        train_dataset: Stored as `self.train_dataset`.
        eval_dataset: Stored as `self.eval_dataset`.
        processing_class: Stored as `self.processing_class` and handed to the callback handler.
        model_init: Callable used (through `self.call_model_init()`) to instantiate the model when `model` is
            `None`.
        compute_loss_func: Stored as `self.compute_loss_func`.
        compute_metrics: Stored as `self.compute_metrics`.
        callbacks: Extra callbacks appended after the default ones.
        optimizers: `(optimizer, lr_scheduler)` tuple unpacked into `self.optimizer` / `self.lr_scheduler`.
        optimizer_cls_and_kwargs: Stored as `self.optimizer_cls_and_kwargs`.
        preprocess_logits_for_metrics: Stored as `self.preprocess_logits_for_metrics`.

    Raises:
        RuntimeError: If neither `model` nor `model_init` is given.
        ValueError: If both `model` and `model_init` are given, if the model class name is listed in
            `MODEL_MAPPING_NAMES`, if FSDP is requested together with DeepSpeed or (non-XLA) outside of distributed
            mode, or if XLA FSDPv2 is requested while `IS_XLA_FSDPV2_POST_2_2` is false.

    Further errors can come from `create_accelerator_and_postprocess()` and `_validate_args()`, both called here.
    """
    # Init flow:
    #   1. Args & seed              – defaults, determinism
    #   2. Accelerator & logging    – accelerator, memory tracker, log level, device setup
    #   3. Model resolution         – model / model_init, Liger Kernel, quantization checks
    #   4. Distributed strategy     – model-parallel, FSDP, SageMaker MP flags
    #   5. Device placement         – move model to device, model wrapping
    #   6. Model introspection      – loss kwargs, label names, label smoother
    #   7. Store init arguments     – data, callables, optimizer, scheduler, validation
    #   8. Callbacks                – reporting integrations, JIT checkpoint, progress bar
    #   9. Hub & output             – repo init, output directory
    #  10. Training state           – TrainerState, TrainerControl, internal bookkeeping
    #  11. Finalize                 – use_cache, XLA FSDPv2 mesh, memory tracker stop

    # ---- 1. Args & seed --------------------------------------------------------
    if args is None:
        output_dir = "tmp_trainer"
        logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
        args = TrainingArguments(output_dir=output_dir)
    self.args = args
    # Seed must be set before instantiating the model when using model_init
    enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)

    # ---- 2. Accelerator & logging ----------------------------------------------
    # `create_accelerator_and_postprocess` reads self.model and self.args,
    # and may set self.deepspeed — store temporary refs before calling it.
    self.deepspeed = None
    self.model = model
    self.create_accelerator_and_postprocess()
    self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
    self._memory_tracker.start()
    log_level = args.get_process_log_level()
    logging.set_verbosity(log_level)
    args._setup_devices  # force device and distributed setup init explicitly

    # ---- 3. Model resolution ----------------------------------------------------
    if model is None:
        if model_init is not None:
            self.model_init = model_init
            model = self.call_model_init()
        else:
            raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
    else:
        if model_init is not None:
            raise ValueError("`Trainer` requires either a `model` or `model_init` argument, but not both.")
        self.model_init = model_init
    if model.__class__.__name__ in MODEL_MAPPING_NAMES:
        raise ValueError(
            f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
            "computes hidden states and does not accept any labels. You should choose a model with a head "
            "suitable for your task like any of the `AutoModelForXxx` listed at "
            "https://huggingface.co/docs/transformers/model_doc/auto"
        )
    validate_quantization_for_training(model)

    # ---- 4. Distributed strategy ------------------------------------------------
    # The model counts as model-parallel when its `hf_device_map` spans several non-cpu/disk devices,
    # or a single one that differs from `args.device`.
    self.is_model_parallel = False
    if getattr(model, "hf_device_map", None) is not None:
        devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
        if len(devices) > 1:
            self.is_model_parallel = True
        elif len(devices) == 1:
            self.is_model_parallel = self.args.device != torch.device(devices[0])
    self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
    if len(args.fsdp) > 0:
        if self.is_deepspeed_enabled:
            raise ValueError(
                "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
            )
        if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
            raise ValueError("Using fsdp only works in distributed training.")
    # Postpone switching model to cuda when MP, DeepSpeed, full bf16/fp16 eval, or FSDP
    # An explicit (non-None) `args.place_model_on_device` always wins over the heuristic below.
    if args.place_model_on_device is not None:
        self.place_model_on_device = args.place_model_on_device
    elif (
        self.is_model_parallel
        or self.is_deepspeed_enabled
        or (args.fp16_full_eval or args.bf16_full_eval)
        or self.is_fsdp_xla_enabled
        or self.is_fsdp_enabled
        or is_sagemaker_mp_enabled()
    ):
        self.place_model_on_device = False
    else:
        self.place_model_on_device = True

    # ---- 5. Device placement ----------------------------------------------------
    # Bnb Quantized models don't support `.to` operation.
    if (
        self.place_model_on_device
        and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
    ):
        self._move_model_to_device(model, args.device)
    # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
    if self.is_model_parallel:
        self.args._n_gpu = 1
    # `self.model is self.model_wrapped` is used later to check if it's wrapped
    self.model_wrapped = model
    self.model = model

    # ---- 6. Model introspection -------------------------------------------------
    unwrapped_model = unwrap_peft_model(self.accelerator.unwrap_model(model))
    # Prefer the model's own `accepts_loss_kwargs` flag; otherwise fall back to checking whether
    # `forward` declares a `**kwargs` parameter.
    if hasattr(unwrapped_model, "accepts_loss_kwargs"):
        self.model_accepts_loss_kwargs = unwrapped_model.accepts_loss_kwargs
    else:
        forward_params = inspect.signature(unwrapped_model.forward).parameters
        self.model_accepts_loss_kwargs = any(
            k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
        )
    # Sequence Parallelism computes its own good_tokens count
    pc = getattr(self.accelerator, "parallelism_config", None)
    if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled:
        self.model_accepts_loss_kwargs = False
    model_to_inspect = unwrap_peft_model(self.model)
    default_label_names = find_labels(model_to_inspect.__class__)
    self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
    self.can_return_loss = can_return_loss(model_to_inspect.__class__)
    if self.args.label_smoothing_factor != 0:
        if getattr(self.model.config, "problem_type", None) == "multi_label_classification":
            warnings.warn(
                "Label smoothing is not compatible with multi-label classification. "
                "Disabling label smoothing for this training run.",
                UserWarning,
            )
            self.label_smoother = None
        else:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
    else:
        self.label_smoother = None

    # ---- 7. Store init arguments ------------------------------------------------
    # Data
    default_collator = (
        DataCollatorWithPadding(processing_class)
        if processing_class is not None
        and isinstance(processing_class, (PreTrainedTokenizerBase, SequenceFeatureExtractor))
        else default_data_collator
    )
    self.data_collator = data_collator if data_collator is not None else default_collator
    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset
    self.processing_class = processing_class
    self.neftune_noise_alpha = args.neftune_noise_alpha
    # Callables
    self.compute_loss_func = compute_loss_func
    self.compute_metrics = compute_metrics
    self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
    # Optimizer & scheduler
    self.optimizer, self.lr_scheduler = optimizers
    self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs
    # Validation runs only now because it reads the attributes stored just above.
    self._validate_args()

    # ---- 8. Callbacks -----------------------------------------------------------
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    if self.args.enable_jit_checkpoint:
        from .trainer_jit_checkpoint import JITCheckpointCallback

        jit_callback = JITCheckpointCallback()
        default_callbacks = default_callbacks + [jit_callback]
        jit_callback.set_trainer(self)
    callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    self.callback_handler = CallbackHandler(
        callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
    )
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

    # ---- 9. Hub & output ---------------------------------------------------------
    self.hub_model_id = None  # Set by init_hf_repo() when push_to_hub is enabled
    if self.args.push_to_hub:
        self.init_hf_repo()
    if self.args.should_save:
        os.makedirs(self.args.output_dir, exist_ok=True)

    # ---- 10. Training state -----------------------------------------------------
    self.control = TrainerControl()
    self.state = TrainerState(
        is_local_process_zero=self.is_local_process_zero(),
        is_world_process_zero=self.is_world_process_zero(),
        stateful_callbacks=[
            cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
        ],
    )
    self.is_in_train = False  # True between train() entry and exit
    self.hp_name = None  # Set by hyperparameter_search() to label the trial
    self.hp_search_backend = None  # Set by hyperparameter_search() (optuna / ray / wandb)
    # Per-process FLOP counter; accumulated into self.state.total_flos then reset
    self.current_flos = 0
    # Set True by _setup_loggers() on first call to self.log()
    self._loggers_initialized = False
    # Lazily filled by _set_signature_columns_if_needed(); caches model.forward param names
    self._signature_columns = None
    # Effective batch size; may be reduced by find_executable_batch_size
    self._train_batch_size = args.train_batch_size
    # Guards one-time LR scheduler creation in create_optimizer_and_scheduler
    self._created_lr_scheduler = False
    self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

    # ---- 11. Finalize -----------------------------------------------------------
    if getattr(self.model, "config", None) is not None:
        self.model.config.use_cache = self.args.use_cache
    self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
    if self.is_fsdp_xla_v2_enabled:
        if not IS_XLA_FSDPV2_POST_2_2:
            raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.")
        num_devices = xr.global_runtime_device_count()
        xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
    self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
    self._memory_tracker.stop_and_update_metrics()
def _validate_args(self) -> None:
    """
    Validate constructor arguments and fail fast on incompatible combinations.

    Reads `self.args`, `self.compute_metrics`, `self.eval_dataset`, `self.train_dataset`, `self.data_collator`,
    `self.optimizer`, `self.lr_scheduler`, `self.optimizer_cls_and_kwargs`, `self.model_init`, `self.model` and the
    FSDP flags, so it must be called after those attributes are set. When SageMaker MP is enabled, `args.fp16` is
    overwritten with the SageMaker configuration value.

    Raises:
        ValueError: For unsupported argument combinations (BF16 with SageMaker MP, `batch_eval_metrics` without a
            `compute_result` parameter, an eval strategy without `eval_dataset`, missing `metric_for_best_model`,
            model/optimizer device mismatch on XLA, unsized train dataset without `max_steps`).
        RuntimeError: For incompatible optimizer arguments (`optimizers` together with `optimizer_cls_and_kwargs`,
            with `model_init`, or with FSDP).
        TypeError: If `data_collator` is not callable but exposes a callable `collate_batch`.
    """
    args = self.args

    # --- SageMaker Model Parallel mixed-precision validation ---
    if is_sagemaker_mp_enabled():
        if args.bf16:
            raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
        if args.fp16 != smp.state.cfg.fp16:
            logger.warning(
                f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                f"but FP16 provided in trainer argument is {args.fp16}, "
                f"setting to {smp.state.cfg.fp16}"
            )
            args.fp16 = smp.state.cfg.fp16

    # --- Training-argument validations ---
    if args.batch_eval_metrics and self.compute_metrics is not None:
        if "compute_result" not in inspect.signature(self.compute_metrics).parameters:
            raise ValueError(
                "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`"
                " boolean argument which will be triggered after the last batch of the eval set to signal that the"
                " summary statistics should be returned by the function."
            )
    if args.eval_strategy is not None and args.eval_strategy != "no" and self.eval_dataset is None:
        raise ValueError(
            f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
        )
    if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end:
        if args.metric_for_best_model is None:
            raise ValueError(
                "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`."
            )

    # --- Optimizer validations ---
    if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None:
        raise RuntimeError("Passing both `optimizers` and `optimizer_cls_and_kwargs` arguments is incompatible.")
    if self.model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
        raise RuntimeError(
            "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
            "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
        )
    if is_torch_xla_available() and self.optimizer is not None:
        # Device of the first model parameter / first optimizer-managed parameter. `None` when there is
        # nothing to inspect (parameter-less model, only empty param groups); the comparison is then
        # skipped instead of failing on an unbound local.
        model_device = next((param.device for param in self.model.parameters()), None)
        optimizer_device = next(
            (
                param_group["params"][0].device
                for param_group in self.optimizer.param_groups
                if len(param_group["params"]) > 0
            ),
            None,
        )
        if model_device is not None and optimizer_device is not None and model_device != optimizer_device:
            raise ValueError(
                "The model and the optimizer parameters are not on the same device, which probably means you"
                " created an optimizer around your model **before** putting on the device and passing it to the"
                " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
            )
    if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
        self.optimizer is not None or self.lr_scheduler is not None
    ):
        raise RuntimeError(
            "Passing `optimizers` is not allowed if PyTorch FSDP is enabled. "
            "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
        )

    # --- Dataset validations ---
    if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
        raise TypeError("The `data_collator` should be a simple callable (function, class with `__call__`).")
    if args.max_steps > 0 and args.num_train_epochs > 0:
        logger.info("max_steps is given, it will override any value given in num_train_epochs")
    if self.train_dataset is not None and not has_length(self.train_dataset) and args.max_steps <= 0:
        raise ValueError(
            "The train_dataset does not implement __len__, max_steps has to be specified. "
            "The number of steps needs to be known in advance for the learning rate scheduler."
        )
    if self.train_dataset is not None and isinstance(self.train_dataset, torch.utils.data.IterableDataset):
        logger.info(
            f"The `train_sampling_strategy='{args.train_sampling_strategy}'` option is ignored when using an `IterableDataset`. "
            "Samplers cannot be used with IterableDataset as they require indexed access to the dataset."
        )
def _build_accelerator_args(self, **kwargs) -> dict[str, Any]:
    """
    Build the keyword arguments used to instantiate the `Accelerator`.

    Args:
        **kwargs: Extra entries merged on top of the base `mixed_precision` / `deepspeed_plugin` arguments.
            Keys computed below (`kwargs_handlers`, `parallelism_config`, `dynamo_plugin`) override them.

    Returns:
        `dict[str, Any]`: The assembled keyword arguments.

    Raises:
        ImportError: If `args.parallelism_config` is set but the installed accelerate is too old.
        ValueError: If the model uses tensor parallelism (`tp_size > 1`) without a user-provided
            `parallelism_config` and the installed accelerate is too old to build one.
    """
    # Minimum accelerate release required for `ParallelismConfig`.
    min_accelerate_version = "1.12.0"

    accelerator_kwargs = {
        "mixed_precision": self.args.mixed_precision,
        "deepspeed_plugin": self.args.deepspeed_plugin,
    }
    accelerator_kwargs.update(kwargs)

    # --- DDP kwargs ---
    if self.args.ddp_find_unused_parameters is not None:
        find_unused = self.args.ddp_find_unused_parameters
    elif isinstance(self.model, PreTrainedModel):
        # find_unused_parameters breaks checkpointing as per
        # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
        find_unused = not (self.model.is_gradient_checkpointing or self.args.gradient_checkpointing)
    else:
        find_unused = True
    ddp_kwargs = {"find_unused_parameters": find_unused}
    if self.args.ddp_bucket_cap_mb is not None:
        ddp_kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
    if self.args.ddp_broadcast_buffers is not None:
        ddp_kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
    accelerator_kwargs["kwargs_handlers"] = [DistributedDataParallelKwargs(**ddp_kwargs)]

    # --- Parallelism config ---
    # We defer compatibility checks to accelerator
    if self.args.parallelism_config is not None:
        if not is_accelerate_available(min_accelerate_version):
            raise ImportError(
                f"ParallelismConfig requires accelerate>={min_accelerate_version}. Please upgrade accelerate to use this feature."
            )
        accelerator_kwargs["parallelism_config"] = self.args.parallelism_config

    if getattr(self.model, "tp_size", None) is not None and self.model.tp_size > 1:
        if self.args.parallelism_config is None:
            # No user-provided config: derive one from the model's tensor-parallel size.
            if not is_accelerate_available(min_accelerate_version):
                raise ValueError(f"Requires accelerate>={min_accelerate_version} to use Tensor Parallelism.")
            from accelerate import ParallelismConfig

            accelerator_kwargs["parallelism_config"] = ParallelismConfig(tp_size=self.model.tp_size)
        elif accelerator_kwargs["parallelism_config"].tp_size != self.model.tp_size:
            # The model's tensor-parallel size takes precedence over the user-provided config.
            accelerator_kwargs["parallelism_config"].tp_size = self.model.tp_size

    # --- torch.compile / dynamo ---
    if is_accelerate_available("1.2.0"):
        # if we don't have the correct version, we will rely on env var instead that were set in TrainingArguments
        from accelerate.utils import TorchDynamoPlugin

        accelerator_kwargs["dynamo_plugin"] = TorchDynamoPlugin(
            backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode
        )
    return accelerator_kwargs
def create_accelerator_and_postprocess(self) -> None:
    """
    Create the accelerator and perform post-creation setup (FSDP, DeepSpeed, etc.).

    Sets `self.accelerator`, `self.gather_function`, `self.is_deepspeed_enabled` and `self.is_fsdp_enabled`.
    May also update `self.args.gradient_accumulation_steps` when `num_steps` is provided through the accelerator
    config's `gradient_accumulation_kwargs`.

    Raises:
        ValueError: If `num_steps` is given in `gradient_accumulation_kwargs` while
            `gradient_accumulation_steps > 1`; if FSDP activation checkpointing and `gradient_checkpointing` are
            both enabled; if `save_only_model` is combined with DeepSpeed/FSDP and `load_best_model_at_end`, or
            with an FSDP `SHARDED_STATE_DICT`; if `auto_find_batch_size` is used with DeepSpeed ZeRO-3.
    """
    # We explicitly don't rely on the `Accelerator` to do gradient accumulation
    grad_acc_kwargs = {}
    if self.args.accelerator_config.gradient_accumulation_kwargs is not None:
        # Work on a copy: `num_steps` is forced to 1 below, and writing that into the user's config would
        # make a later `Trainer` built from the same arguments trip the `num_steps` check.
        grad_acc_kwargs = dict(self.args.accelerator_config.gradient_accumulation_kwargs)

    # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
    if "num_steps" in grad_acc_kwargs:
        if self.args.gradient_accumulation_steps > 1:
            # raise because we do not know which setting is intended.
            raise ValueError(
                "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`. "
                "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
            )
        else:
            self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]

    # The Trainer handles GAS itself, so GAS=1 in Accelerate to avoid any double-division
    grad_acc_kwargs["num_steps"] = 1
    # Just making sure that gradient_state have the correct values passed.
    # We don't rely on `accumulate` from accelerate to set sync_gradients in gradient_state.
    # Rather, we do it ourselves by setting self.accelerator.gradient_state._set_sync_gradients.
    gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

    accelerator_config = self.args.accelerator_config.to_dict()

    # Extract dataloader config params from accelerator config
    dataloader_params = ["split_batches", "dispatch_batches", "even_batches", "use_seedable_sampler"]
    dataloader_config = DataLoaderConfiguration(
        **{param: accelerator_config.pop(param) for param in dataloader_params}
    )
    dataloader_config.data_seed = self.args.data_seed

    non_blocking = accelerator_config.pop("non_blocking")
    if non_blocking and not self.args.dataloader_pin_memory:
        logger.warning(
            "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both."
        )
    dataloader_config.non_blocking = non_blocking

    # this would have been updated above, no need for it anymore
    accelerator_config.pop("gradient_accumulation_kwargs")

    fsdp_plugin = None
    if self.args.fsdp_plugin_args is not None:
        from accelerate.utils import FullyShardedDataParallelPlugin

        fsdp_plugin = FullyShardedDataParallelPlugin(**self.args.fsdp_plugin_args)

    args = self._build_accelerator_args(
        dataloader_config=dataloader_config,
        fsdp_plugin=fsdp_plugin,
        gradient_accumulation_plugin=gradient_accumulation_plugin,
    )

    # create accelerator object
    self.accelerator = Accelerator(**args)
    # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
    self.gather_function = self.accelerator.gather_for_metrics
    if "use_gather_object" in inspect.signature(self.gather_function).parameters:
        self.gather_function = functools.partial(
            self.gather_function, use_gather_object=self.args.eval_use_gather_object
        )

    # deepspeed and accelerate flags covering both trainer args and accelerate launcher
    self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
    self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

    # post accelerator creation setup
    if self.is_fsdp_enabled:
        fsdp_plugin = self.accelerator.state.fsdp_plugin
        # Values from `args.fsdp_config` take precedence; otherwise keep what the plugin already has.
        for param in ["limit_all_gathers", "activation_checkpointing"]:
            setattr(fsdp_plugin, param, self.args.fsdp_config.get(param, getattr(fsdp_plugin, param)))
        if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
            raise ValueError(
                "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
                "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
                "when using FSDP."
            )

    if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
        propagate_args_to_deepspeed(self.accelerator, self.args)

    # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
    if (
        self.args.save_only_model
        and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
        and self.args.load_best_model_at_end
    ):
        wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
        raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")

    # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3
    if (
        self.is_deepspeed_enabled
        and self.accelerator.state.deepspeed_plugin.zero_stage == 3
        and self.args.auto_find_batch_size
    ):
        raise ValueError(
            "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP"
        )

    if (
        self.args.save_only_model
        and self.is_fsdp_enabled
        and "SHARDED_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)
    ):
        raise ValueError("save_only_model option is not compatible with FSDP state dict type 'SHARDED_STATE_DICT'")
- # ---- Data Loading ----
def get_train_dataloader(self) -> DataLoader:
    """
    Build the [`~torch.utils.data.DataLoader`] used for training.

    No sampler is used when `train_dataset` does not implement `__len__`; otherwise a random sampler (adapted to
    distributed training if necessary) is used.

    Subclass and override this method to inject custom behavior.

    Raises:
        ValueError: If no `train_dataset` was provided to the trainer.
    """
    train_dataset = self.train_dataset
    if train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    loader_kwargs = {
        "dataset": train_dataset,
        "description": "Training",
        "batch_size": self._train_batch_size,
        "sampler_fn": self._get_train_sampler,
        "is_training": True,
    }
    return self._get_dataloader(**loader_kwargs)
- def get_eval_dataloader(self, eval_dataset: str | Dataset | None = None) -> DataLoader:
- """
- Returns the evaluation [`~torch.utils.data.DataLoader`].
- Subclass and override this method if you want to inject some custom behavior.
- Args:
- eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
- If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
- """
- if eval_dataset is None and self.eval_dataset is None:
- raise ValueError("Trainer: evaluation requires an eval_dataset.")
- # If we have persistent workers, don't do a fork bomb especially as eval datasets
- # don't change during training
- dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
- if (
- hasattr(self, "_eval_dataloaders")
- and dataloader_key in self._eval_dataloaders
- and self.args.dataloader_persistent_workers
- ):
- return self._eval_dataloaders[dataloader_key]
- eval_dataset = (
- self.eval_dataset[eval_dataset]
- if isinstance(eval_dataset, str)
- else eval_dataset
- if eval_dataset is not None
- else self.eval_dataset
- )
- return self._get_dataloader(
- dataset=eval_dataset,
- description="Evaluation",
- batch_size=self.args.eval_batch_size,
- sampler_fn=self._get_eval_sampler,
- dataloader_key=dataloader_key,
- )
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
    """
    Build the [`~torch.utils.data.DataLoader`] used for prediction on a test set.

    Subclass and override this method to inject custom behavior.

    Args:
        test_dataset (`torch.utils.data.Dataset`, *optional*):
            The test dataset to use. It must implement `__len__`. For a [`~datasets.Dataset`], columns not
            accepted by the `model.forward()` method are automatically removed.
    """
    loader_kwargs = {
        "dataset": test_dataset,
        "description": "test",
        "batch_size": self.args.eval_batch_size,
        "sampler_fn": self._get_eval_sampler,
    }
    return self._get_dataloader(**loader_kwargs)
def num_examples(self, dataloader: DataLoader) -> int:
    """
    Return the number of samples behind a [`~torch.utils.data.DataLoader`], read from its dataset.

    When `dataloader.dataset` is missing or has no length, the count is estimated as
    `len(dataloader) * args.per_device_train_batch_size`.
    """
    try:
        dataset = dataloader.dataset
        # Special case for IterableDatasetShard, we need to dig deeper
        sized = dataset.dataset if isinstance(dataset, IterableDatasetShard) else dataset
        return len(sized)
    except (NameError, AttributeError, TypeError):
        # no dataset or length, estimate by length of dataloader
        estimate = len(dataloader) * self.args.per_device_train_batch_size
        return estimate
def _get_dataloader(
    self,
    dataset: Dataset,
    description: str,
    batch_size: int,
    sampler_fn: Callable[[Dataset], torch.utils.data.Sampler] | None = None,
    is_training: bool = False,
    dataloader_key: str | None = None,
) -> DataLoader:
    """
    Create a [`~torch.utils.data.DataLoader`] from the given dataset and prepare it with the accelerator.

    Args:
        dataset: The dataset to load from.
        description: Label of the split, forwarded to the column-removal helpers.
        batch_size: Batch size handed to the `DataLoader`.
        sampler_fn: Optional factory called with the dataset to build a sampler (non-iterable datasets only).
        is_training: When `True`, a seeded `worker_init_fn` is installed (non-iterable datasets only).
        dataloader_key: When set and persistent workers are enabled, the prepared dataloader is cached in
            `self._eval_dataloaders` under this key.
    """
    # A `datasets.Dataset` has its unused columns dropped directly; any other dataset gets a collator
    # wrapper that drops them per batch.
    if is_datasets_available() and isinstance(dataset, datasets.Dataset):
        collate_fn = self.data_collator
        dataset = self._remove_unused_columns(dataset, description=description)
    else:
        collate_fn = self._get_collator_with_removed_columns(self.data_collator, description=description)

    # MPS requires forking if multiple workers are specified
    should_fork = torch.backends.mps.is_available() and self.args.dataloader_num_workers > 1

    dataloader_params = {
        "batch_size": batch_size,
        "collate_fn": collate_fn,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
        "persistent_workers": self.args.dataloader_persistent_workers,
        "multiprocessing_context": "fork" if should_fork else None,
    }

    is_iterable = isinstance(dataset, torch.utils.data.IterableDataset)
    if not is_iterable:
        if sampler_fn is not None:
            dataloader_params["sampler"] = sampler_fn(dataset)
        dataloader_params["drop_last"] = self.args.dataloader_drop_last
        dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
        if is_training:
            worker_init = partial(
                seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
            )
            dataloader_params["worker_init_fn"] = worker_init

    dataloader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params))

    # Store the prepared dataloader for subsequent evaluations if using persistent workers.
    if dataloader_key is not None and self.args.dataloader_persistent_workers:
        if not hasattr(self, "_eval_dataloaders"):
            self._eval_dataloaders = {}
        self._eval_dataloaders[dataloader_key] = dataloader

    return dataloader
def _get_train_sampler(self, train_dataset: Dataset | None = None) -> torch.utils.data.Sampler | None:
    """
    Return the training sampler selected by `args.train_sampling_strategy`.

    Falls back to `self.train_dataset` when no dataset is passed. Returns `None` when there is no dataset or it
    has no length. `"group_by_length"` yields a `LengthGroupedSampler`, `"sequential"` a `SequentialSampler`,
    anything else a `RandomSampler`.
    """
    dataset = self.train_dataset if train_dataset is None else train_dataset
    if dataset is None or not has_length(dataset):
        return None

    strategy = self.args.train_sampling_strategy
    if strategy != "group_by_length":
        sampler_cls = SequentialSampler if strategy == "sequential" else RandomSampler
        return sampler_cls(dataset)

    # Precomputed lengths are only taken from a `datasets.Dataset` that has the configured length column.
    lengths = None
    if is_datasets_available() and isinstance(dataset, datasets.Dataset):
        if self.args.length_column_name in dataset.column_names:
            lengths = dataset[self.args.length_column_name]

    model_input_name = None
    if self.processing_class is not None:
        model_input_name = self.processing_class.model_input_names[0]

    return LengthGroupedSampler(
        self.args.train_batch_size * self.args.gradient_accumulation_steps,
        dataset=dataset,
        lengths=lengths,
        model_input_name=model_input_name,
    )
def _get_eval_sampler(self, eval_dataset: Dataset) -> torch.utils.data.Sampler | None:
    """
    Return the evaluation sampler.

    Returns `None` when the dataset is missing or has no length. With `train_sampling_strategy ==
    "group_by_length"` a `LengthGroupedSampler` over `eval_batch_size` is used. Otherwise a `SequentialSampler`
    is returned when `world_size <= 1`, and `None` when running on several processes.
    """
    if eval_dataset is None or not has_length(eval_dataset):
        return None

    if self.args.train_sampling_strategy != "group_by_length":
        return SequentialSampler(eval_dataset) if self.args.world_size <= 1 else None

    # Precomputed lengths are only taken from a `datasets.Dataset` that has the configured length column.
    lengths = None
    if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
        if self.args.length_column_name in eval_dataset.column_names:
            lengths = eval_dataset[self.args.length_column_name]

    model_input_name = None
    if self.processing_class is not None:
        model_input_name = self.processing_class.model_input_names[0]

    return LengthGroupedSampler(
        self.args.eval_batch_size,
        dataset=eval_dataset,
        lengths=lengths,
        model_input_name=model_input_name,
    )
def _set_signature_columns_if_needed(self) -> None:
    """
    Fill `self._signature_columns` once, from the parameter names of the model's `forward` plus label columns.

    Does nothing when `self._signature_columns` is already set. For PEFT models the base model's `forward` is
    inspected instead of the wrapper's.
    """
    if self._signature_columns is not None:
        return

    # Inspect model forward signature to keep only the arguments it accepts.
    model_to_inspect = self.model
    if _is_peft_model(self.model):
        if hasattr(self.model, "get_base_model"):
            model_to_inspect = self.model.get_base_model()
        else:
            # PeftMixedModel do not provide a `get_base_model` method
            model_to_inspect = self.model.base_model.model

    forward_params = inspect.signature(model_to_inspect.forward).parameters
    self._signature_columns = list(forward_params.keys())
    # Labels may be named label or label_ids, the default data collator handles that.
    label_columns = list(set(["label", "label_ids"] + self.label_names))
    self._signature_columns.extend(label_columns)
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: str | None = None) -> "datasets.Dataset":
    """
    Drop the dataset columns that the model's `forward` method does not accept.

    The dataset is returned untouched when `args.remove_unused_columns` is disabled. Dropped columns are reported
    through an info-level log message.

    Raises:
        ValueError: If none of the dataset columns matches the signature columns.
    """
    if not self.args.remove_unused_columns:
        return dataset

    self._set_signature_columns_if_needed()
    signature_columns = self._signature_columns
    ignored_columns = list(set(dataset.column_names) - set(signature_columns))

    if ignored_columns:
        dset_description = "" if description is None else f"in the {description} set"
        logger.info(
            f"The following columns {dset_description} don't have a corresponding argument in "
            f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
            f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
            " you can safely ignore this message."
        )

    columns = [name for name in signature_columns if name in dataset.column_names]
    if not columns:
        raise ValueError(
            f"No columns in the dataset match the model's forward method signature: ({', '.join(signature_columns)}). "
            f"The following columns have been ignored: [{', '.join(ignored_columns)}]. "
            "Please check the dataset and model. You may need to set `remove_unused_columns=False` in `TrainingArguments`."
        )

    # `datasets` releases older than 1.4.0 restrict the visible columns through `set_format`;
    # newer ones go through `remove_columns`.
    if version.parse(datasets.__version__) >= version.parse("1.4.0"):
        return dataset.remove_columns(ignored_columns)

    current_format = dataset.format
    dataset.set_format(type=current_format["type"], columns=columns, format_kwargs=current_format["format_kwargs"])
    return dataset
def _get_collator_with_removed_columns(self, data_collator: Callable, description: str | None = None) -> Callable:
    """
    Return `data_collator` wrapped in a `RemoveColumnsCollator` restricted to the signature columns.

    The collator is returned unchanged when `args.remove_unused_columns` is disabled.
    """
    if not self.args.remove_unused_columns:
        return data_collator

    self._set_signature_columns_if_needed()
    return RemoveColumnsCollator(
        data_collator=data_collator,
        signature_columns=self._signature_columns,
        logger=logger,
        description=description,
        model_name=self.model.__class__.__name__,
    )
- # ---- Optimizer & Scheduler & Learning rate ----
def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
    """
    Create the optimizer first, then the learning rate scheduler.

    The defaults are meant to work well out of the box. To use something else, pass a tuple through `optimizers`
    in the Trainer's init, or subclass and override this method (or `create_optimizer` and/or `create_scheduler`).

    Args:
        num_training_steps (`int`): Forwarded to `create_scheduler` as `num_training_steps`.
    """
    self.create_optimizer()
    scheduler_kwargs = {"num_training_steps": num_training_steps}
    self.create_scheduler(**scheduler_kwargs)
def create_optimizer(self, model=None) -> torch.optim.Optimizer:
    """
    Setup the optimizer.

    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
    the Trainer's init through `optimizers`, or override this method in a subclass.

    Nothing is created when `self.optimizer` is already set.

    Args:
        model (*optional*):
            The module whose parameters are optimized. Defaults to `self.model`.

    Returns:
        `torch.optim.Optimizer`: The optimizer instance.
    """
    opt_model = self.model if model is None else model

    if self.optimizer is None:
        # A set gives O(1) membership tests while the parameters are split in a single pass.
        decay_parameters = set(self.get_decay_parameter_names(opt_model))
        decay_params, no_decay_params = [], []
        for name, param in opt_model.named_parameters():
            if not param.requires_grad:
                continue
            if name in decay_parameters:
                decay_params.append(param)
            else:
                no_decay_params.append(param)
        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": self.args.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        if self.optimizer_cls_and_kwargs is not None:
            optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
        else:
            optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

        # Work on a shallow copy: the `pop` calls below must not strip keys from the dict stored in
        # `self.optimizer_cls_and_kwargs`, otherwise a later call to this method (after `self.optimizer` has
        # been reset to `None`) would silently lose `params` / `model` / `optimizer_dict`.
        optimizer_kwargs = dict(optimizer_kwargs)

        # Check if this is a factory (for complex optimizers like Muon, Dion)
        # Factories are instantiated first, then called with (opt_model, **kwargs)
        if is_optimizer_factory(optimizer_cls):
            self.optimizer = optimizer_cls()(opt_model, **optimizer_kwargs)
        else:
            # Standard optimizer class instantiation
            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for GaLore optimizer.
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for LOMO optimizer.
            if "model" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("model")

            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
            # to avoid arguments conflicts.
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
            import bitsandbytes

            manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

            # Embedding weights are registered with a 32-bit override instead of the 8-bit setting.
            skipped = 0
            for module in opt_model.modules():
                if isinstance(module, nn.Embedding):
                    # Keyed by `data_ptr()` so parameters sharing storage are counted once per module.
                    skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                    logger.info(f"skipped {module}: {skipped / 2**20}M params")
                    manager.register_module_override(module, "weight", {"optim_bits": 32})
                    logger.debug(f"bitsandbytes: will optimize {module} in fp32")
            logger.info(f"skipped: {skipped / 2**20}M params")

    if is_sagemaker_mp_enabled():
        self.optimizer = smp.DistributedOptimizer(self.optimizer)

    return self.optimizer
def create_scheduler(
    self, num_training_steps: int, optimizer: torch.optim.Optimizer | None = None
) -> torch.optim.lr_scheduler.LRScheduler:
    """
    Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called
    or passed as an argument.

    An existing `self.lr_scheduler` is returned untouched.

    Args:
        num_training_steps (`int`):
            The number of training steps to do.
        optimizer (`torch.optim.Optimizer`, *optional*):
            The optimizer to schedule. When omitted, the trainer's own optimizer is used.

    Returns:
        `torch.optim.lr_scheduler.LRScheduler`: The learning rate scheduler instance.
    """
    if self.lr_scheduler is not None:
        return self.lr_scheduler

    if optimizer is None:
        # If fp16 is enabled under SageMaker MP, we unwrap the optimizer
        needs_unwrap = is_sagemaker_mp_enabled() and smp.state.cfg.fp16
        optimizer = self.optimizer.optimizer if needs_unwrap else self.optimizer

    self.lr_scheduler = get_scheduler(
        self.args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
        num_training_steps=num_training_steps,
        scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
    )
    # Remember that this scheduler was created here rather than supplied from outside.
    self._created_lr_scheduler = True
    return self.lr_scheduler
@staticmethod
def get_optimizer_cls_and_kwargs(args: TrainingArguments, model: PreTrainedModel | None = None) -> tuple[Any, Any]:
    """
    Returns the optimizer class and optimizer parameters based on the training arguments.

    The handler registered for `args.optim` in `_OPTIMIZER_HANDLERS` is called with an `OptimizerContext`
    built from `args` and `model`.

    Args:
        args (`transformers.training_args.TrainingArguments`):
            The training arguments for the training session.
        model (`PreTrainedModel`, *optional*):
            The model being trained. Required for some optimizers (GaLore, Apollo, LOMO).

    Returns:
        A tuple containing the optimizer class and a dictionary of optimizer keyword arguments.

    Raises:
        ValueError: If no handler is registered for `args.optim`.
    """
    base_kwargs = {"lr": args.learning_rate}
    adam_defaults = {
        "betas": (args.adam_beta1, args.adam_beta2),
        "eps": args.adam_epsilon,
    }
    context = OptimizerContext(
        args=args,
        model=model,
        optimizer_kwargs=base_kwargs,
        adam_kwargs=adam_defaults,
        optim_args=_parse_optim_args(args.optim_args),
    )

    build = _OPTIMIZER_HANDLERS.get(args.optim)
    if build is None:
        raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
    return build(context)
def get_decay_parameter_names(self, model: nn.Module) -> list[str]:
    """
    Get all parameter names that weight decay will be applied to.

    Two exclusion criteria are handed to `get_parameter_names`:
    1. A layer type list containing `nn.LayerNorm`.
    2. Parameter name patterns (containing 'bias', or variation of 'norm').

    Args:
        model (`nn.Module`):
            The model whose parameter names are inspected.

    Returns:
        `list[str]`: The names returned by `get_parameter_names` for `model`.
    """
    excluded_layer_types = [nn.LayerNorm]
    excluded_name_patterns = [r"bias", r"layernorm", r"rmsnorm", r"(?:^|\.)norm(?:$|\.)", r"_norm(?:$|\.)"]
    return get_parameter_names(model, excluded_layer_types, excluded_name_patterns)
- def _get_learning_rate(self) -> float:
- """
- Returns the current learning rate from the scheduler.
- Handles DeepSpeed's dynamic loss scaling warmup period where `get_last_lr` may fail.
- """
- if self.is_deepspeed_enabled:
- # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
- # not run for the first few dozen steps while loss scale is too large, and thus during
- # that time `get_last_lr` will fail if called during that warm up stage, so work around it:
- try:
- last_lr = self.lr_scheduler.get_last_lr()[0]
- except AssertionError as e:
- if "need to call step" in str(e):
- logger.warning("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
- last_lr = 0
- else:
- raise
- else:
- if isinstance(self.lr_scheduler, (torch.optim.lr_scheduler.ReduceLROnPlateau, GreedyLR)):
- last_lr = self.optimizer.param_groups[0]["lr"]
- else:
- last_lr = self.lr_scheduler.get_last_lr()[0]
- if torch.is_tensor(last_lr):
- last_lr = last_lr.item()
- return last_lr
- # ---- Training ----
def train(
    self,
    resume_from_checkpoint: str | bool | None = None,
    trial: "optuna.Trial | dict[str, Any] | None" = None,
    ignore_keys_for_eval: list[str] | None = None,
) -> TrainOutput:
    """
    Main training entry point.

    Prepares the model (optional re-init, device placement, gradient checkpointing, special-token alignment,
    NEFTune hooks), resolves the checkpoint to resume from, and then delegates to `_inner_training_loop`
    wrapped in `find_executable_batch_size`.

    Args:
        resume_from_checkpoint (`str` or `bool`, *optional*):
            If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
            `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
            of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
        trial (`optuna.Trial` or `dict[str, Any]`, *optional*):
            The trial run or the hyperparameter dictionary for hyperparameter search.
        ignore_keys_for_eval (`list[str]`, *optional*):
            A list of keys in the output of your model (if it is a dictionary) that should be ignored when
            gathering predictions for evaluation during the training.

    Returns:
        [`~trainer_utils.TrainOutput`]: Object containing the global step count, training loss, and metrics.

    Raises:
        ValueError: If `--debug underflow_overflow` is requested while `args.n_gpu > 1`, or if
            `resume_from_checkpoint=True` and no checkpoint is found in `args.output_dir`.
    """
    # `False` and `None` both mean "do not resume"; normalize to `None`.
    if resume_from_checkpoint is False:
        resume_from_checkpoint = None

    # memory metrics - must set up as early as possible
    self._memory_tracker.start()

    args = self.args

    self.is_in_train = True

    # Model re-init
    if self.model_init is not None:
        # Seed must be set before instantiating the model when using model_init.
        enable_full_determinism(args.seed) if args.full_determinism else set_seed(args.seed)
        self.model = self.call_model_init(trial)
        # Reinitializes optimizer and scheduler
        self.optimizer, self.lr_scheduler = None, None
        if self.place_model_on_device:
            self._move_model_to_device(self.model, args.device)
        # The freshly created model is not wrapped yet.
        self.model_wrapped = self.model

    if self.args.use_liger_kernel:
        apply_liger_kernel(self.model, self.args.liger_kernel_config)

    # When fp16/bf16 full eval is enabled, __init__ skips device placement so that
    # evaluation_loop can cast dtype and move in one step. Move the model now for training.
    if (args.fp16_full_eval or args.bf16_full_eval) and not self.is_model_parallel and self.model_init is None:
        self._move_model_to_device(self.model, args.device)

    # Activate gradient checkpointing if needed
    if args.gradient_checkpointing:
        self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)

    # If the model uses a tokenizer, it may have new tokens for fine-tuning purposes.
    if isinstance(self.processing_class, (PreTrainedTokenizerBase, ProcessorMixin)) and hasattr(
        self.model, "config"
    ):
        align_special_tokens(self.model, self.processing_class)

    # Attach NEFTune hooks if necessary
    if self.neftune_noise_alpha is not None:
        self.neftune_hook_handle = activate_neftune(self.model, self.neftune_noise_alpha, self.accelerator)

    # This might change the seed so needs to run first.
    self._hp_search_setup(trial)

    if DebugOption.UNDERFLOW_OVERFLOW in args.debug:
        if args.n_gpu > 1:
            # nn.DataParallel(model) replicates the model, creating new variables and module
            # references registered here no longer work on other gpus, breaking the module
            raise ValueError(
                "Currently --debug underflow_overflow is not supported under DP. Please use DDP with torchrun"
            )
        else:
            DebugUnderflowOverflow(self.model)

    # Load potential model checkpoint
    # `True` means "resume from the last checkpoint found in `args.output_dir`".
    if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
        resume_from_checkpoint = get_last_checkpoint(args.output_dir)
        if resume_from_checkpoint is None:
            raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

    if resume_from_checkpoint is not None:
        # Load model checkpoint before accelerator.prepare() for regular models,
        # so that buffers and parameters are on the right device after prepare.
        # Deepspeed/FSDP models are loaded after prepare in _prepare_for_training.
        if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
            self._load_from_checkpoint(resume_from_checkpoint)
        state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
        if state.train_batch_size is not None and args.auto_find_batch_size:
            # Only restore the checkpoint's train_batch_size when using auto_find_batch_size.
            self._train_batch_size = state.train_batch_size

    # Wrap the inner loop with `find_executable_batch_size`, starting from `self._train_batch_size`.
    inner_training_loop = find_executable_batch_size(
        self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
    )
    # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
    ctx = suppress_progress_bars() if args.push_to_hub else contextlib.nullcontext()
    with ctx:
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )
def _inner_training_loop(
    self,
    batch_size: int | None = None,
    args: TrainingArguments | None = None,
    resume_from_checkpoint: str | None = None,
    trial: "optuna.Trial | dict[str, Any] | None" = None,
    ignore_keys_for_eval: list[str] | None = None,
) -> TrainOutput:
    """
    Run the actual training loop: forward, backward, optimizer step, logging, and checkpointing.

    Args:
        batch_size (`int`, *optional*):
            Batch size to try; only used when `args.auto_find_batch_size` is set.
        args (`TrainingArguments`, *optional*):
            The training arguments. Although the default is `None`, attributes are read from it
            unconditionally, so callers are expected to pass it.
        resume_from_checkpoint (`str`, *optional*):
            Path of the checkpoint to resume from, if any.
        trial (`optuna.Trial` or `dict[str, Any]`, *optional*):
            The trial run or the hyperparameter dictionary for hyperparameter search.
        ignore_keys_for_eval (`list[str]`, *optional*):
            Forwarded to the evaluation calls made during training.

    Returns:
        [`~trainer_utils.TrainOutput`]: The result of `_finalize_training`.
    """
    # reset everything
    self.accelerator.free_memory()
    if args.auto_find_batch_size:
        self._update_auto_batch_size(batch_size)

    # Data loader and number of training steps
    train_dataloader = self.get_train_dataloader()
    if self.is_fsdp_xla_v2_enabled:
        train_dataloader = tpu_spmd_dataloader(train_dataloader)

    # Setting up training control variables:
    (
        num_train_epochs,
        num_update_steps_per_epoch,
        num_examples,
        num_train_samples,
        total_train_batch_size,
        steps_in_epoch,
        max_steps,
    ) = self.set_initial_training_values(args, train_dataloader)

    # Fresh (or checkpoint-restored) `self.state`, plus the resume offsets used to skip already-seen data.
    epochs_trained, steps_trained_in_current_epoch = self._init_training_state(
        max_steps, num_update_steps_per_epoch, num_train_epochs, resume_from_checkpoint, trial
    )

    model, train_dataloader = self._prepare_for_training(max_steps, train_dataloader, resume_from_checkpoint)

    # Train!
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {num_examples:,}")
    logger.info(f" Num Epochs = {num_train_epochs:,}")
    logger.info(f" Num update steps per epoch = {num_update_steps_per_epoch:,}")
    logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
    if self.args.per_device_train_batch_size != self._train_batch_size:
        logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_steps:,}")
    logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

    if resume_from_checkpoint is not None:
        logger.info(
            f" Resuming training from checkpoint with epoch {epochs_trained} and global step {self.state.global_step}"
        )
        if not self.args.ignore_data_skip:
            logger.info(
                f" Fast-forwarding the dataloader past {epochs_trained} epochs and"
                f" {steps_trained_in_current_epoch} batches to resume from the exact training state."
            )

    start_time = time.time()
    # needed to calculate tokens/s
    self._initial_num_input_tokens_seen = self.state.num_input_tokens_seen

    # Logging state: _tr_loss accumulates on-device between logging steps (avoiding costly .item() syncs
    # on TPUs), then gets drained into _total_loss_scalar at each logging step.
    self._tr_loss = torch.tensor(0.0, device=args.device)
    self._total_loss_scalar = 0.0
    self._globalstep_last_logged = self.state.global_step

    model.zero_grad()

    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    if args.eval_on_start:
        self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

    # One `_run_epoch` call per epoch, starting at the first epoch that was not fully trained yet.
    for epoch in range(epochs_trained, num_train_epochs):
        self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
        self._run_epoch(
            model=model,
            epoch=epoch,
            train_dataloader=train_dataloader,
            steps_in_epoch=steps_in_epoch,
            num_update_steps_per_epoch=num_update_steps_per_epoch,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
            start_time=start_time,
            resume_from_checkpoint=resume_from_checkpoint,
            epochs_trained=epochs_trained,
            steps_trained_in_current_epoch=steps_trained_in_current_epoch,
        )
        if self.control.should_training_stop:
            break

    return self._finalize_training(trial, num_train_samples, start_time)
def _init_training_state(
    self, max_steps, num_update_steps_per_epoch, num_train_epochs, resume_from_checkpoint, trial
) -> tuple[int, int]:
    """
    Initialize `self.state`, optionally restoring it from a checkpoint.

    A new `TrainerState` is always created first. If `resume_from_checkpoint` contains a trainer state file,
    that file replaces the new state and the resume offsets are derived from its `global_step`.

    Returns:
        `tuple[int, int]`: `(epochs_trained, steps_trained_in_current_epoch)`. Both are `0` when not resuming;
        the second is also `0` when `args.ignore_data_skip` is set.
    """
    exportable_callbacks = [
        cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
    ]
    self.state = TrainerState(stateful_callbacks=exportable_callbacks)
    self.state.is_hyper_param_search = trial is not None
    self.state.train_batch_size = self._train_batch_size
    self.state.compute_steps(self.args, max_steps)

    epochs_trained, steps_trained_in_current_epoch = 0, 0

    # The path is only built when a checkpoint was requested.
    state_file = None
    if resume_from_checkpoint is not None:
        state_file = os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)

    if state_file is not None and os.path.isfile(state_file):
        self.state = TrainerState.load_from_json(state_file)
        compare_trainer_and_checkpoint_args(self.args, self.state)
        self._load_callback_state()
        epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
        if not self.args.ignore_data_skip:
            # Optimizer steps already done in the partial epoch, converted to a number of micro-batches.
            update_steps_done = self.state.global_step % num_update_steps_per_epoch
            steps_trained_in_current_epoch = update_steps_done * self.args.gradient_accumulation_steps

    self.state.init_training_references(self, max_steps, num_train_epochs, trial)
    return epochs_trained, steps_trained_in_current_epoch
def _prepare_for_training(self, max_steps, train_dataloader, resume_from_checkpoint):
    """
    Wrap model, create optimizer and scheduler, and run accelerator.prepare.

    Also loads DeepSpeed / SageMaker MP / FSDP model checkpoints (which must happen after preparation),
    restores optimizer, scheduler and scaler state, and refreshes the references held by the callback handler.

    Args:
        max_steps: Total number of training steps; used for DeepSpeed init and scheduler creation.
        train_dataloader: The training dataloader; may be replaced by an adapted one (see return value).
        resume_from_checkpoint: Checkpoint path to resume from, or `None`.

    Returns:
        `(model, train_dataloader)`: the wrapped/prepared model and the (possibly adapted) dataloader.
    """
    delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled

    # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404
    is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
    if is_fsdp2:
        delay_optimizer_creation = False

    # We need to reset the scheduler, as its parameters may be different on subsequent calls
    if self._created_lr_scheduler:
        self.lr_scheduler = None
        self._created_lr_scheduler = False

    if self.is_deepspeed_enabled:
        self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

    if not delay_optimizer_creation:
        self.create_optimizer()

    # Pass `self.model_wrapped` so that `_wrap_model` can detect if the model is already
    # wrapped (e.g. in DataParallel) on subsequent `train()` calls and avoid double wrapping.
    model = self._wrap_model(self.model_wrapped)

    # If the model is wrapped, don't use `accelerator.prepare`
    # this is for unhandled cases in accelerate such as FSDP-XLA, SageMaker MP/DP, DataParallel
    use_accelerator_prepare = model is self.model

    # prepare using `accelerator` prepare
    if use_accelerator_prepare:
        if delay_optimizer_creation:
            # TODO: check if we can move this somewhere else
            if self.is_fsdp_enabled and _is_peft_model(self.model):
                update_fsdp_plugin_peft(self.model, self.accelerator)
            # we only prepare the model as we don't have an optimizer
            model = self.accelerator.prepare(self.model)
            # using the model we prepared to create the optimizer
            self.create_optimizer(model)
            self.optimizer = self.accelerator.prepare(self.optimizer)
        elif self.is_deepspeed_enabled and type(self.lr_scheduler).__name__ == "DummyScheduler":
            # The DeepSpeed dummy scheduler is prepared together with the model and optimizer.
            model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                self.model, self.optimizer, self.lr_scheduler
            )
        else:
            model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
    else:
        # The model was wrapped by `_wrap_model`; only the optimizer goes through `accelerator.prepare`.
        self.optimizer = self.accelerator.prepare(self.optimizer)

    # Create scheduler now that the optimizer won't change anymore
    self.create_scheduler(num_training_steps=max_steps)

    # updating self.model_wrapped
    self.model_wrapped = model

    if self.is_fsdp_enabled or self.is_fsdp_xla_enabled:
        # breaking convention for FSDP model
        # TODO: check if this is really needed
        self.model = self.model_wrapped = model

    # backward compatibility
    # TODO: check if we really need this
    if self.is_deepspeed_enabled:
        self.deepspeed = self.model_wrapped

    # Important: at this point:
    # self.model         is the Transformers Model except when we are using FSDP
    # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
    # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.

    if self.is_fsdp_enabled:
        # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA
        if hasattr(self.model, "generate"):
            dist.fsdp.register_fsdp_forward_method(self.model, "generate")

    # since DataLoader was Accelerate prepared w/o a model arg in the same call, we now have to complete the
    # DL wrapping for ALST/UlyssesSP, after model has been prepared
    pc = getattr(self.accelerator, "parallelism_config", None)
    if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled:
        train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model)

    # load checkpoint
    if resume_from_checkpoint is not None:
        if self.is_deepspeed_enabled:
            deepspeed_load_checkpoint(
                self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
            )
        elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
            self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

    # Called unconditionally; `resume_from_checkpoint` may be `None` here.
    self._load_optimizer_and_scheduler(resume_from_checkpoint)
    self._load_scaler(resume_from_checkpoint)

    # Update the references for the callback_handler
    for attr in ("model", "optimizer", "lr_scheduler"):
        setattr(self.callback_handler, attr, getattr(self, attr))
    self.callback_handler.train_dataloader = train_dataloader

    return model, train_dataloader
def _run_epoch(
    self,
    model,
    epoch,
    train_dataloader,
    steps_in_epoch,
    num_update_steps_per_epoch,
    trial,
    ignore_keys_for_eval,
    start_time,
    resume_from_checkpoint,
    epochs_trained,
    steps_trained_in_current_epoch,
):
    """
    Run one full pass over the dataloader.

    Args:
        model: The (wrapped) model passed to `training_step`.
        epoch: Index of the epoch being run.
        train_dataloader: The training dataloader; already-trained batches are skipped on resume.
        steps_in_epoch: Number of micro-batches in an epoch, as compared against the micro-batch counter `step`.
        num_update_steps_per_epoch: Number of optimizer steps in an epoch.
        trial: Forwarded to `_maybe_log_save_evaluate`.
        ignore_keys_for_eval: Forwarded to `_maybe_log_save_evaluate`.
        start_time: Forwarded to `_maybe_log_save_evaluate`.
        resume_from_checkpoint: Checkpoint path being resumed from, or `None`.
        epochs_trained: First epoch index of this run; resume handling only applies when `epoch` equals it.
        steps_trained_in_current_epoch: Micro-batches already trained in the resumed epoch.
    """
    # `step` counts micro-batches within the epoch; it stays at -1 if the dataloader yields nothing.
    step = -1
    grad_norm = None
    learning_rate = None
    rng_to_sync = False

    # Handle resumption from checkpoint: skip already-trained batches in the resumed epoch
    num_update_steps_trained = 0
    if epoch == epochs_trained and resume_from_checkpoint is not None:
        if steps_trained_in_current_epoch > 0 and not self.args.ignore_data_skip:
            train_dataloader = skip_first_batches(train_dataloader, steps_trained_in_current_epoch)
            step = steps_trained_in_current_epoch - 1
            num_update_steps_trained = steps_trained_in_current_epoch // self.args.gradient_accumulation_steps
            # The RNG state is restored later, after the first batches have been fetched.
            rng_to_sync = True
        elif steps_trained_in_current_epoch == 0:
            self._load_rng_state(resume_from_checkpoint)

    if hasattr(train_dataloader, "set_epoch"):
        train_dataloader.set_epoch(epoch)

    epoch_iterator = iter(train_dataloader)

    # We chunkify the epoch iterator into gradient accumulation steps `n` batches
    # `remainder` is the number of micro-batches in the last optimizer step of the epoch.
    remainder = steps_in_epoch % self.args.gradient_accumulation_steps
    if remainder == 0:
        remainder = self.args.gradient_accumulation_steps

    # Outer loop: one iteration per optimizer step. Each iteration prefetches
    # `gradient_accumulation_steps` batches (fewer for the last step if the epoch
    # doesn't divide evenly).
    for update_step in range(num_update_steps_trained, num_update_steps_per_epoch):
        num_batches = (
            self.args.gradient_accumulation_steps if update_step != (num_update_steps_per_epoch - 1) else remainder
        )
        batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, self.args.device)

        # This is used to correctly scale the loss when the last accumulation step has fewer batches.
        # Not used if `num_items_in_batch` is not None.
        self.current_gradient_accumulation_steps = len(batch_samples)

        # need to sync after if we skipped the batches in `get_batch_samples` for shuffle order reason
        if rng_to_sync:
            self._load_rng_state(resume_from_checkpoint)
            rng_to_sync = False

        # Inner loop: forward + backward for each micro-batch. Gradients are
        # accumulated without syncing until the last micro-batch, then we clip,
        # step the optimizer, and log/save/evaluate.
        for i, inputs in enumerate(batch_samples):
            step += 1
            # An optimizer step happens on accumulation boundaries and on the last micro-batch of the epoch.
            do_sync_step = (step + 1) % self.args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
            # Since we perform prefetching, we need to manually set sync_gradients
            self.accelerator.gradient_state._set_sync_gradients(do_sync_step)

            if step % self.args.gradient_accumulation_steps == 0:
                self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)

            # We sync the gradients in the following cases:
            # 1. sync_each_batch set to True 2. Using deepspeed 3. when we are at the last batch sample
            if (
                self.accelerator.gradient_state.plugin_kwargs.get("sync_each_batch", False)
                or self.accelerator.distributed_type == DistributedType.DEEPSPEED
                or i == len(batch_samples) - 1
            ):
                sync_context = contextlib.nullcontext
            else:
                sync_context = functools.partial(self.accelerator.no_sync, model=model)

            with sync_context():
                tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

            if (
                self.args.logging_nan_inf_filter
                and not is_torch_xla_available()
                and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
            ):
                # if loss is nan or inf simply add the average of previous logged losses
                self._tr_loss += self._tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
            else:
                if self._tr_loss.device != tr_loss_step.device:
                    raise ValueError(
                        f"Calculated loss must be on the original device: {self._tr_loss.device} but device in use is {tr_loss_step.device}"
                    )
                self._tr_loss += tr_loss_step

            self.current_flos += float(self.floating_point_ops(inputs))
            self._track_num_input_tokens(inputs)

            if do_sync_step:
                grad_norm = None
                if self.args.max_grad_norm > 0:
                    grad_norm = self._clip_grad_norm(model)
                grad_norm = self._get_grad_norm(model, grad_norm=grad_norm)

                self.control = self.callback_handler.on_pre_optimizer_step(self.args, self.state, self.control)
                self.optimizer.step()
                self.control = self.callback_handler.on_optimizer_step(self.args, self.state, self.control)

                # get learning rate before update
                learning_rate = self._get_learning_rate()

                if not self.accelerator.optimizer_step_was_skipped:
                    # Delay optimizer scheduling until metrics are generated
                    if not isinstance(self.lr_scheduler, (torch.optim.lr_scheduler.ReduceLROnPlateau, GreedyLR)):
                        self.lr_scheduler.step()

                model.zero_grad()
                self.state.global_step += 1
                # Fractional epoch progress.
                self.state.epoch = epoch + (step + 1) / steps_in_epoch
                self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
                self._maybe_log_save_evaluate(
                    self._tr_loss,
                    grad_norm,
                    model,
                    trial,
                    epoch,
                    ignore_keys_for_eval,
                    start_time,
                    learning_rate=learning_rate,
                )
            else:
                self.control = self.callback_handler.on_substep_end(self.args, self.state, self.control)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break

        # We also need to break out of the outer (optimizer step) loop.
        if self.control.should_epoch_stop or self.control.should_training_stop:
            break

    # PyTorch/XLA relies on the dataloader to insert mark_step each iteration.
    # When we break out of the loop early, we flush the pending graph manually.
    if is_torch_xla_available():
        xm.mark_step()

    if step < 0:
        logger.warning(
            "There seems not to be a single sample in your epoch_iterator, stopping training at step"
            f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
            f" num_steps ({self.state.max_steps}) higher than the number of available samples."
        )
        self.control.should_training_stop = True

    self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
    self._maybe_log_save_evaluate(
        self._tr_loss,
        grad_norm,
        model,
        trial,
        epoch,
        ignore_keys_for_eval,
        start_time,
        learning_rate=learning_rate,
    )
def _finalize_training(self, trial, num_train_samples, start_time):
    """
    Finalize training: metrics, best-model loading, cleanup.

    Args:
        trial: Used to resolve the output directory whose checkpoints are inspected.
        num_train_samples: Passed to `speed_metrics` as `num_samples`.
        start_time: Passed to `speed_metrics` as the start of the measured interval.

    Returns:
        `TrainOutput`: Built from the global step, the average training loss and the logged metrics.
    """
    logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")

    # add remaining tr_loss
    self._total_loss_scalar += self._tr_loss.item()
    # The divisor is floored at 0.001 to avoid a ZeroDivisionError when no step was completed.
    train_loss = self._total_loss_scalar / max(self.state.global_step, 0.001)

    metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
    self.store_flos()
    metrics["total_flos"] = self.state.total_flos
    metrics["train_loss"] = train_loss
    self._memory_tracker.stop_and_update_metrics(metrics)
    self.log(metrics)

    if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
        self._load_best_model()

    run_dir = self._get_output_dir(trial)
    checkpoints_sorted = sort_checkpoints(
        output_dir=run_dir, best_model_checkpoint=self.state.best_model_checkpoint
    )

    # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and
    # process allowed to save.
    keep_only_best = (
        self.args.should_save
        and self.state.best_model_checkpoint is not None
        and self.args.save_total_limit == 1
    )
    if keep_only_best:
        for checkpoint_dir in checkpoints_sorted:
            if os.path.samefile(checkpoint_dir, self.state.best_model_checkpoint):
                continue
            logger.info(f"Deleting older checkpoint [{checkpoint_dir}] due to args.save_total_limit")
            shutil.rmtree(checkpoint_dir, ignore_errors=True)

    self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)

    # Wait for the checkpoint to be uploaded.
    self._finish_current_push()

    # After training we make sure to retrieve back the original forward pass method
    # for the embedding layer by removing the forward post hook.
    if self.neftune_noise_alpha is not None:
        deactivate_neftune(self.model, self.neftune_hook_handle, self.accelerator)

    self.is_in_train = False
    return TrainOutput(self.state.global_step, train_loss, metrics)
def training_step(
    self,
    model: nn.Module,
    inputs: dict[str, torch.Tensor | Any],
    num_items_in_batch: torch.Tensor | int | None = None,
) -> torch.Tensor:
    """
    Perform a training step on a batch of inputs: forward pass, loss computation and backward pass.

    Subclass and override to inject custom behavior.

    Args:
        model (`nn.Module`):
            The model to train.
        inputs (`dict[str, torch.Tensor | Any]`):
            The inputs and targets of the model.

            The dictionary will be unpacked before being fed to the model. Most models expect the targets under
            the argument `labels`. Check your model's documentation for all accepted arguments.
        num_items_in_batch (`torch.Tensor` or `int`, *optional*):
            Forwarded to `compute_loss`. When it is `None`, the loss is divided by the current number of
            gradient accumulation steps (unless a custom `compute_loss_func` is set).

    Return:
        `torch.Tensor`: The tensor with training loss on this batch.
    """
    # Prepare buffers for context parallelism
    cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs)

    # Context manager is no-op if CP isn't enabled
    with cp_context():
        model.train()
        # Some optimizers expose their own `train()` switch; call it when present.
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)

        # SageMaker MP runs forward and backward itself, so we return right after it.
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

        del inputs

        empty_cache_every = self.args.torch_empty_cache_steps
        if empty_cache_every is not None and self.state.global_step % empty_cache_every == 0:
            clear_device_cache()

        backward_kwargs = {}

        # For LOMO optimizers you need to explicitly use the learning rate
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            backward_kwargs["learning_rate"] = self._get_learning_rate()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
        loss_not_normalized = not self.model_accepts_loss_kwargs or num_items_in_batch is None
        if loss_not_normalized and self.compute_loss_func is None:
            # If the model does not accept loss kwargs, we need to normalize the loss by the number of
            # gradient accumulation steps
            loss = loss / self.current_gradient_accumulation_steps

        # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
        # https://github.com/huggingface/transformers/pull/35808
        if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
            backward_kwargs["scale_wrt_gas"] = False

        self.accelerator.backward(loss, **backward_kwargs)

        return loss.detach()
def compute_loss(
    self,
    model: nn.Module,
    inputs: dict[str, torch.Tensor | Any],
    return_outputs: bool = False,
    num_items_in_batch: torch.Tensor | int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, Any]:
    """
    Run the forward pass and derive the loss used by the Trainer.

    The loss comes from, in order of precedence: the DeepSpeed sequence-parallel helper (only while
    `self.model` is in training mode), the user-supplied `compute_loss_func`, the label smoother, or the
    model's own output (the `"loss"` key of a dict output, the first element otherwise).

    Args:
        model (`nn.Module`):
            The model to compute the loss for.
        inputs (`dict[str, torch.Tensor | Any]`):
            The input data for the model. `"labels"` is popped from this dict when a label smoother or a
            custom loss function is configured.
        return_outputs (`bool`, *optional*, defaults to `False`):
            Whether to return the model outputs along with the loss.
        num_items_in_batch (`torch.Tensor | int`, *optional*):
            The number of items in the batch. Forwarded to the model when `self.model_accepts_loss_kwargs`
            is set, and always forwarded to `compute_loss_func`.

    Returns:
        The loss, or a `(loss, outputs)` tuple when `return_outputs=True`.

    Raises:
        ValueError: If the model returned a dict without a `"loss"` key and nothing else computes the loss.

    Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your loss,
    make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculation might be
    slightly inaccurate when performing gradient accumulation.
    """
    parallelism = getattr(self.accelerator, "parallelism_config", None)
    use_deepspeed_sp = (
        parallelism is not None
        and parallelism.sp_backend == "deepspeed"
        and parallelism.sp_enabled
        and self.model.training
    )
    if use_deepspeed_sp:
        return deepspeed_sp_compute_loss(self.accelerator, model, inputs, return_outputs, parallelism)

    # Labels are withheld from the model only when something other than the model computes the loss
    loss_computed_outside_model = self.label_smoother is not None or self.compute_loss_func is not None
    labels = inputs.pop("labels") if loss_computed_outside_model and "labels" in inputs else None

    if self.model_accepts_loss_kwargs:
        if num_items_in_batch is None:
            inputs = {**inputs}
        else:
            inputs = {**inputs, "num_items_in_batch": num_items_in_batch}
    outputs = model(**inputs)

    if self.compute_loss_func is not None:
        # User-defined loss function takes precedence over everything else
        if labels is None:
            logger.warning(
                "Trainer: `compute_loss_func` is defined but `labels=None`. "
                "Your custom loss function will still be called with labels=None. "
            )
        loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
    elif labels is not None:
        # Default HF handling: label smoothing, with shifted labels for causal LM architectures
        unwrapped_model = self.accelerator.unwrap_model(model)
        if _is_peft_model(unwrapped_model):
            model_name = unwrapped_model.base_model.model._get_name()
        else:
            model_name = unwrapped_model._get_name()
        if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)
    elif isinstance(outputs, dict):
        if "loss" not in outputs:
            raise ValueError(
                "The model did not return a loss from the inputs, only the following keys: "
                f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
            )
        loss = outputs["loss"]
    else:
        # `.loss` is not used because the model may return a plain tuple instead of a ModelOutput
        loss = outputs[0]

    rescale_by_device_count = (
        self.args.average_tokens_across_devices
        and (self.model_accepts_loss_kwargs or self.compute_loss_func)
        and num_items_in_batch is not None
    )
    if rescale_by_device_count:
        # NOTE(review): presumably offsets the cross-device averaging applied to the gradients - confirm
        loss *= self.accelerator.num_processes if self.args.n_gpu <= 1 else self.args.n_gpu

    return (loss, outputs) if return_outputs else loss
def compute_loss_context_manager(self) -> contextlib.ExitStack:
    """
    Build the stack of context managers under which the loss is computed.

    The autocast context returned by `autocast_smart_context_manager` is entered on the stack unless it is
    a `contextlib.nullcontext`, in which case the stack is returned empty.
    """
    stack = contextlib.ExitStack()
    autocast = self.autocast_smart_context_manager()
    if isinstance(autocast, contextlib.nullcontext):
        return stack
    stack.enter_context(autocast)
    return stack
- def autocast_smart_context_manager(self, cache_enabled: bool | None = True) -> contextlib.AbstractContextManager:
- """
- A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
- arguments, depending on the situation. We rely on accelerate for autocast, hence we do nothing here.
- """
- return contextlib.nullcontext()
def _maybe_log_save_evaluate(
    self,
    tr_loss: torch.Tensor,
    grad_norm: torch.Tensor | float | None,
    model: nn.Module,
    trial: "optuna.Trial | dict[str, Any] | None",
    epoch: float,
    ignore_keys_for_eval: list[str] | None,
    start_time: float,
    learning_rate: float | None = None,
) -> None:
    """
    Log metrics, run evaluation, and save a checkpoint, each only if `self.control` requests it.

    Args:
        tr_loss (`torch.Tensor`):
            Loss accumulated since the last log. It is zeroed in place once it has been logged.
        grad_norm (`torch.Tensor | float`, *optional*):
            Gradient norm to report; left out of the logs when `None`.
        model (`nn.Module`):
            Model handed to `_save_checkpoint`.
        trial:
            Hyperparameter-search trial forwarded to evaluation, best-metric tracking and checkpointing.
        epoch (`float`):
            Not used by this implementation.
        ignore_keys_for_eval (`list[str]`, *optional*):
            Forwarded to `_evaluate`.
        start_time (`float`):
            Forwarded to `self.log`.
        learning_rate (`float`, *optional*):
            Learning rate to report; `self._get_learning_rate()` is queried when `None`.
    """
    if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
        if is_torch_xla_available():
            xm.mark_step()

        # all_gather + mean() to get average loss over all processes
        gathered_loss = nested_gather(tr_loss, self.args.parallel_mode).mean().item()

        # In-place subtraction resets the caller's accumulator tensor to zero
        tr_loss -= tr_loss

        steps_since_last_log = self.state.global_step - self._globalstep_last_logged
        logs: dict[str, float] = {"loss": gathered_loss / steps_since_last_log}
        if grad_norm is not None:
            logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
        logs["learning_rate"] = learning_rate if learning_rate is not None else self._get_learning_rate()

        self._total_loss_scalar += gathered_loss
        self._globalstep_last_logged = self.state.global_step
        self.store_flos()

        self.log(logs, start_time)

    if self.control.should_evaluate:
        eval_metrics = self._evaluate(trial, ignore_keys_for_eval)
        improved = self._determine_best_metric(metrics=eval_metrics, trial=trial)

        # With the "best" strategy, saving is decided solely by whether the tracked metric improved
        if self.args.save_strategy == SaveStrategy.BEST:
            self.control.should_save = improved

    if self.control.should_save:
        self._save_checkpoint(model, trial)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
- # ---- Training Utilities ----
def get_batch_samples(
    self, epoch_iterator: Iterator, num_batches: int, device: torch.device
) -> tuple[list, torch.Tensor | int | None]:
    """
    Pull up to `num_batches` batches from `epoch_iterator` and count the items they contain.

    Fewer batches are returned when the iterator runs out early. The item count comes from
    `_get_num_items_in_batch` and is used to properly scale the loss.
    """
    exhausted = object()
    collected = []
    while len(collected) < num_batches:
        batch = next(epoch_iterator, exhausted)
        if batch is exhausted:
            break
        collected.append(batch)
    return collected, self._get_num_items_in_batch(collected, device)
def _get_num_items_in_batch(self, batch_samples: list, device: torch.device) -> torch.Tensor | int | None:
    """
    Count the label positions different from -100 across `batch_samples`, to properly scale the loss.

    Args:
        batch_samples (`list`): List of batches
        device (`torch.device`): The device on which the number of items in the batch should be.

    Returns:
        `None` when no count is needed (no batches, no `"labels"` in the first batch, or neither the model
        nor a custom loss function consumes the count) or when the labels cannot be counted; otherwise the
        number of items in the batch.
    """
    # The count is consumed either by the model forward (loss kwargs) or by `compute_loss_func`
    count_is_consumed = self.model_accepts_loss_kwargs or self.compute_loss_func is not None
    if not (len(batch_samples) > 0 and "labels" in batch_samples[0] and count_is_consumed):
        return None

    try:
        # For now we don't support object detection
        item_count = sum(batch["labels"].ne(-100).sum() for batch in batch_samples)
    except (TypeError, AttributeError):
        return None

    if self.args.average_tokens_across_devices:
        if self.args.world_size > 1:
            item_count = self.accelerator.gather(item_count.to(device)).sum()
    elif self.args.n_gpu > 1:
        # In DP case, if we don't average, we need to divide by the number of gpu. This is the simplest
        # approximation. Otherwise, we would have to scatter labels and calculate the count for each gpu.
        item_count = item_count // self.args.n_gpu

    if torch.is_tensor(item_count):
        item_count = item_count.to(device)

        if self.args.n_gpu > 1 and item_count.dim() == 0:
            # In the DataParallel case, repeat the scalar once per GPU
            item_count = item_count.unsqueeze(0).expand(self.args.n_gpu, -1)
        # Divide by number of devices with the same batch
        if pc := getattr(self.accelerator, "parallelism_config", None):
            item_count = item_count // pc.non_data_parallel_size

    return item_count
def _prepare_input(self, data: torch.Tensor | Any) -> torch.Tensor | Any:
    """
    Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.

    Mappings, tuples (including namedtuples) and lists are rebuilt with the same container type around
    their recursively prepared elements. Tensors are moved to `self.args.device`. Anything else is
    returned unchanged.

    Args:
        data (`torch.Tensor | Any`):
            A tensor, or an arbitrarily nested mapping/tuple/list containing tensors.

    Returns:
        The same structure with every tensor moved to the target device.
    """
    if isinstance(data, Mapping):
        return type(data)({k: self._prepare_input(v) for k, v in data.items()})
    if isinstance(data, (tuple, list)):
        prepared = [self._prepare_input(v) for v in data]
        # A namedtuple constructor takes one positional argument per field, not a single iterable
        if isinstance(data, tuple) and hasattr(data, "_fields"):
            return type(data)(*prepared)
        return type(data)(prepared)
    if isinstance(data, torch.Tensor):
        kwargs = {"device": self.args.device}
        if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
            # NLP models inputs are int/uint and those get adjusted to the right dtype of the
            # embedding. Other models such as wav2vec2's inputs are already float and thus
            # may need special handling to match the dtypes of the model
            kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
        return data.to(**kwargs)
    return data
def _prepare_inputs(self, inputs: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
    """
    Run `inputs` through `_prepare_input` and reject an empty batch.

    Raises:
        ValueError: If the prepared batch has no entries.
    """
    prepared = self._prepare_input(inputs)
    if len(prepared) > 0:
        return prepared
    raise ValueError(
        "The batch received was empty, your model won't be able to train on it. Double-check that your "
        f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
    )
def _prepare_context_parallel_inputs(
    self, model: nn.Module, inputs: dict[str, torch.Tensor | Any]
) -> tuple[Callable, dict[str, torch.Tensor | Any]]:
    """
    Prepare inputs for context parallelism by setting up buffers and validation.

    Args:
        model: The model being trained
        inputs: Input tensors to prepare. Missing `shift_labels` (torch backend only) and `position_ids`
            are added to this dict in place.

    Returns:
        tuple: (context_manager, prepared_inputs) where context_manager is either a partial of
        `accelerator.maybe_context_parallel` bound to the collected buffers, or `contextlib.nullcontext`
        when context parallelism is not enabled.

    Raises:
        ValueError: If the torch backend is used with a non-SDPA attention implementation, or if the
            provided attention mask is not causal.
    """
    parallelism = getattr(self.accelerator, "parallelism_config", None)
    if parallelism is None or not parallelism.cp_enabled:
        return contextlib.nullcontext, inputs

    if parallelism.cp_backend == "torch":
        if hasattr(model, "config") and model.config._attn_implementation != "sdpa":
            raise ValueError(
                f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
            )

        if "shift_labels" not in inputs:
            logger.warning_once("Shift labels not found in the inputs, shifting manually")
            if "labels" in inputs:
                # Pad one ignored position on the right, then drop the first column
                padded_labels = nn.functional.pad(inputs["labels"], (0, 1), value=-100)
                inputs["shift_labels"] = padded_labels[:, 1:].contiguous()

    # note: we don't do anything for accelerator.parallelism_config.sp_backend == "deepspeed" since:
    # - accelerator.parallelism_config performs the `model.config._attn_implementation` checks already and it
    #   supports more than `sdpa`
    # - UlyssesSPDataLoaderAdapter called from Accelerate performs the `shift_label` creation - must not interfere
    # - position_ids generation should be done by HF Trainer if it wasn't done by the user

    if "position_ids" not in inputs:
        logger.warning_once("Position IDs not found in the inputs, generating manually")
        input_ids = inputs["input_ids"]
        inputs["position_ids"] = torch.arange(input_ids.size(1), device=input_ids.device).expand(
            input_ids.size(0), -1
        )

    # Every buffer collected below is split along dimension 1 (the sequence dimension)
    buffers = [inputs[name] for name in ("input_ids", "labels", "shift_labels") if name in inputs]

    if "attention_mask" in inputs:
        attention_mask = inputs["attention_mask"]
        # Only validate causal mask once for performance
        if not getattr(self, "_attn_mask_causal_checked", False):
            # Context parallel currently doesn't support other masks than causal
            # Accelerate applies hooks to replace mask with is_causal arg in SDPA
            if not is_attention_mask_causal(attention_mask):
                raise ValueError(
                    "Context parallelism only supports causal attention masks. "
                    "The provided attention_mask is not causal. "
                    "Please ensure your data uses causal masking (lower triangular) "
                    "or remove the attention_mask to use the model's default causal masking."
                )
            self._attn_mask_causal_checked = True
        # Masks of any other dimensionality are kept as-is, without sharding, to avoid incorrect splits
        if attention_mask.dim() == 2:
            buffers.append(attention_mask)

    if "position_ids" in inputs and inputs["position_ids"] is not None:
        buffers.append(inputs["position_ids"])

    context_factory = partial(
        self.accelerator.maybe_context_parallel,
        buffers=buffers,
        buffer_seq_dims=[1] * len(buffers),
        no_restore_buffers=set(buffers),
    )
    return context_factory, inputs
def set_initial_training_values(
    self, args: TrainingArguments, dataloader: DataLoader
) -> tuple[int, int, int, int, int, int | None, int]:
    """
    Derive the step, epoch and sample counts that drive the training loop.

    Returns a tuple of, in order:

    - `num_train_epochs`
    - `num_update_steps_per_epoch`
    - `num_examples`
    - `num_train_samples`
    - `total_train_batch_size`
    - `steps_in_epoch` (total batches per epoch)
    - `max_steps`

    Raises:
        ValueError: If the dataloader has no usable length and `args.max_steps` is not positive.
    """
    # A negative `args.max_steps` means the total number of steps is derived from the number of epochs
    max_steps = args.max_steps
    derive_steps_from_epochs = max_steps < 0

    len_dataloader = len(dataloader) if has_length(dataloader) else None
    total_train_batch_size = self.get_total_train_batch_size(args)

    # Account for Sequence Parallelism (SP) dataloader adapter's effect
    sp_size = self.get_sp_size()
    if len_dataloader is not None and sp_size > 1:
        len_dataloader *= sp_size

    if len_dataloader is not None:
        # Optimizer updates per epoch: batches / accumulation steps, rounded up, never below one
        full_updates, leftover_batches = divmod(len_dataloader, args.gradient_accumulation_steps)
        num_update_steps_per_epoch = max(full_updates + (1 if leftover_batches > 0 else 0), 1)
        if derive_steps_from_epochs:
            max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)

    if len_dataloader:
        num_examples = self.num_examples(dataloader)
        if args.max_steps > 0:
            full_epochs, leftover_steps = divmod(max_steps, num_update_steps_per_epoch)
            num_train_epochs = full_epochs + (1 if leftover_steps > 0 else 0)
            # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
            # the best we can do.
            num_train_samples = max_steps * total_train_batch_size
        else:
            num_train_epochs = math.ceil(args.num_train_epochs)
            num_train_samples = self.num_examples(dataloader) * args.num_train_epochs
    elif args.max_steps > 0:
        # No working dataloader size: rely on `max_steps` and use a very large number of epochs so the
        # iterator is traversed as many times as necessary.
        num_train_epochs = sys.maxsize
        num_update_steps_per_epoch = max_steps
        num_examples = total_train_batch_size * args.max_steps
        num_train_samples = args.max_steps * total_train_batch_size
    else:
        raise ValueError(
            "args.max_steps must be set to a positive value if dataloader does not have a length, was"
            f" {args.max_steps}"
        )

    if len_dataloader is not None:
        steps_in_epoch = len_dataloader
    else:
        steps_in_epoch = max_steps * args.gradient_accumulation_steps

    return (
        num_train_epochs,
        num_update_steps_per_epoch,
        num_examples,
        num_train_samples,
        total_train_batch_size,
        steps_in_epoch,
        max_steps,
    )
def get_total_train_batch_size(self, args: TrainingArguments) -> int:
    """Calculates total batch size (micro_batch * grad_accum * dp_world_size).

    The data-parallel world size is what remains of `args.world_size` after floor-dividing by the
    tensor-, context- and sequence-parallel sizes in turn:

        dp_world_size = world_size // tp_size // cp_size // sp_size

    Where:
    - TP (Tensor Parallelism): Model layers split across GPUs
    - CP (Context Parallelism): Sequences split using Ring Attention (FSDP2)
    - SP (Sequence Parallelism): Sequences split using ALST/Ulysses (DeepSpeed)

    All dimensions are separate and multiplicative: world_size = dp_size * tp_size * cp_size * sp_size
    """
    non_data_parallel_sizes = (self.get_tp_size(), self.get_cp_size(), self.get_sp_size())
    dp_world_size = args.world_size
    for size in non_data_parallel_sizes:
        dp_world_size //= size
    return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size
def get_sp_size(self) -> int:
    """Return the sequence parallel size, or 1 when the accelerator has no parallelism config."""
    parallelism = getattr(self.accelerator, "parallelism_config", None)
    return 1 if parallelism is None else parallelism.sp_size
def get_cp_size(self) -> int:
    """Return the context parallel size, or 1 when the accelerator has no parallelism config."""
    parallelism = getattr(self.accelerator, "parallelism_config", None)
    return 1 if parallelism is None else parallelism.cp_size
def get_tp_size(self) -> int:
    """
    Get the tensor parallel size from either the model or DeepSpeed config.

    The model's `_tp_size` attribute wins when it is set. Otherwise, with DeepSpeed enabled, the
    `tensor_parallel.autotp_size` entry of the DeepSpeed config is used. The result is 1 in every other case.
    """
    model_tp_size = getattr(self.model, "_tp_size", None)
    if model_tp_size is not None:
        return model_tp_size

    if not self.is_deepspeed_enabled:
        return 1

    deepspeed_config = getattr(self.args, "hf_deepspeed_config", None)
    if not deepspeed_config:
        return 1

    tensor_parallel_section = deepspeed_config.config.get("tensor_parallel", {})
    return tensor_parallel_section.get("autotp_size", 1)
def _wrap_model(self, model: nn.Module, training: bool = True, dataloader: DataLoader | None = None) -> nn.Module:
    """
    Wrap `model` for distributed training if needed (DataParallel, XLA FSDP, SageMaker).

    A model that accelerate already reports as wrapped is returned untouched. `dataloader` is accepted but
    not used by this implementation.

    Args:
        model (`nn.Module`): The model to wrap.
        training (`bool`, *optional*, defaults to `True`):
            When `False`, only the SageMaker model-parallel and `nn.DataParallel` wrappers are considered.
        dataloader (`DataLoader`, *optional*): Unused.

    Returns:
        The wrapped model, or `model` itself when no wrapping applies.
    """
    # train/eval could be run multiple times - if already wrapped, don't re-wrap it again
    already_wrapped = self.accelerator.unwrap_model(model, keep_torch_compile=False) is not model
    if already_wrapped:
        return model

    if is_sagemaker_mp_enabled():
        # Wrapping the base model twice in a DistributedModel will raise an error.
        if isinstance(model, smp.model.DistributedModel):
            return model
        return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

    # Multi-gpu training, 8bit models do not support DP
    use_data_parallel = self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False)
    if use_data_parallel:
        model = nn.DataParallel(model)

    # Note: in torch.distributed mode, there's no point in wrapping the model
    # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
    if not training:
        return model

    if self.is_fsdp_xla_enabled:
        return wrap_model_xla_fsdp(model, self.args, self.is_fsdp_xla_v2_enabled)

    if is_sagemaker_dp_enabled():
        local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
        return nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

    return model
def _update_auto_batch_size(self, batch_size):
    """
    Free memory, reset model wrapping, and update the DeepSpeed config for the new batch size when using
    `auto_find_batch_size`.

    Args:
        batch_size: The new value stored in `self._train_batch_size`.
    """
    # `_train_batch_size` value might have changed due to `auto_find_batch_size`
    self._train_batch_size = batch_size

    # frees the wrapped model and resets it back to the unwrapped base model
    release_memory(self.model_wrapped)

    if self.is_fsdp_enabled:
        # Remove FSDP wrapping from sub-models because self.model points to the wrapped model in FSDP case
        self.model = unwrap_model(self.model, recursive=True)

    self.model_wrapped = self.model

    # Check for DeepSpeed *after* the initial pass and modify the config
    if self.is_deepspeed_enabled:
        # Temporarily override `per_device_train_batch_size` so the DeepSpeed config picks up the new size
        original_bs = self.args.per_device_train_batch_size
        self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
        try:
            propagate_args_to_deepspeed(self.accelerator, self.args, auto_find_batch_size=True)
        finally:
            # Restore the user's setting even if propagating to DeepSpeed fails
            self.args.per_device_train_batch_size = original_bs
def _track_num_input_tokens(self, inputs):
    """
    Add the number of input tokens in `inputs` to `self.state.num_input_tokens_seen`.

    Nothing is counted when `args.include_num_input_tokens_seen` is `"no"` or when the model's main input
    is missing from `inputs`. In `"non_padding"` mode the count is the sum of the attention mask if present,
    else the number of tokens different from the processing class's pad token, else all tokens. Any other
    mode counts all tokens. The per-process count is gathered through the accelerator and summed.
    """
    mode = self.args.include_num_input_tokens_seen
    if mode == "no":
        return

    main_input_name = getattr(self.model, "main_input_name", "input_ids")
    if main_input_name not in inputs:
        logger.warning(
            "Tried to track the number of tokens seen, however the current model is "
            "not configured properly to know what item is the input. To fix this, add "
            "a `main_input_name` attribute to the model class you are using."
        )
        return

    if mode != "non_padding":
        input_tokens = inputs[main_input_name].numel()
    elif "attention_mask" in inputs:
        input_tokens = inputs["attention_mask"].sum()
    elif (
        self.processing_class is not None
        and hasattr(self.processing_class, "pad_token_id")
        and self.processing_class.pad_token_id is not None
    ):
        input_tokens = (inputs[main_input_name] != self.processing_class.pad_token_id).sum()
    else:
        logger.warning(
            "Could not determine method to count non-padding tokens, falling back to counting all tokens."
        )
        input_tokens = inputs[main_input_name].numel()

    local_count = torch.as_tensor(input_tokens, device=self.args.device, dtype=torch.int64)
    self.state.num_input_tokens_seen += self.accelerator.gather(local_count).sum().item()
def _clip_grad_norm(self, model):
    """
    Clip gradients to `args.max_grad_norm` and return whatever the clipping call returns.

    With SageMaker model parallelism and fp16, clipping goes through the optimizer's `clip_master_grads`;
    in every other case it goes through `accelerator.clip_grad_norm_` on the model's parameters.
    """
    max_norm = self.args.max_grad_norm
    use_master_grads = is_sagemaker_mp_enabled() and self.args.fp16
    if use_master_grads:
        return self.optimizer.clip_master_grads(max_norm)
    return self.accelerator.clip_grad_norm_(model.parameters(), max_norm)
def _get_grad_norm(self, model, grad_norm=None):
    """
    Return the gradient norm, computing it first when `grad_norm` is not supplied.

    Under DeepSpeed the value is converted with `.item()` when it exposes that method; otherwise it is
    returned as obtained, so it may still be a tensor.
    """
    if grad_norm is None:
        # An infinite max norm means the norm is computed without any actual clipping
        grad_norm = self.accelerator.clip_grad_norm_(model.parameters(), float("inf"))

    running_deepspeed = self.accelerator.distributed_type == DistributedType.DEEPSPEED
    if running_deepspeed and hasattr(grad_norm, "item"):
        return grad_norm.item()
    return grad_norm
- # ---- Evaluation & Prediction ----
def evaluate(
    self,
    eval_dataset: Dataset | dict[str, Dataset] | None = None,
    ignore_keys: list[str] | None = None,
    metric_key_prefix: str = "eval",
) -> dict[str, float]:
    """
    Run evaluation and return the metrics.

    Computing task-specific metrics is the responsibility of the calling script, through the
    `compute_metrics` argument given at init. Subclass and override this method to inject custom behavior.

    Args:
        eval_dataset (`Dataset` | dict[str, `Dataset`], *optional*):
            Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
            not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will
            evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the
            `__len__` method.

            <Tip>

            If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run
            separate evaluations on each dataset. This can be useful to monitor how training affects other
            datasets or simply to get a more fine-grained evaluation.
            When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one
            of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets
            `data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` for using the
            loss on `data1` and `metric_for_best_model="eval_data2_loss"` for the loss on `data2`.

            </Tip>

        ignore_keys (`list[str]`, *optional*):
            A list of keys in the output of your model (if it is a dictionary) that should be ignored when
            gathering predictions.
        metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
            An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
            "eval_bleu" if the prefix is "eval" (default)

    Returns:
        A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
        dictionary also contains the epoch number which comes from the training state.
    """
    # Several eval datasets: evaluate each one separately and merge the prefixed metrics
    override = eval_dataset is not None
    if not override:
        eval_dataset = self.eval_dataset
    if isinstance(eval_dataset, dict):
        merged_metrics = {}
        for dataset_name, dataset in eval_dataset.items():
            # Without an override only the name is passed on, not the dataset itself
            merged_metrics.update(
                self.evaluate(
                    eval_dataset=dataset if override else dataset_name,
                    ignore_keys=ignore_keys,
                    metric_key_prefix=f"{metric_key_prefix}_{dataset_name}",
                )
            )
        return merged_metrics

    # memory metrics - must set up as early as possible
    self._memory_tracker.start()

    eval_dataloader = self.get_eval_dataloader(eval_dataset)
    if self.is_fsdp_xla_v2_enabled:
        eval_dataloader = tpu_spmd_dataloader(eval_dataloader)

    start_time = time.time()

    # No point gathering the predictions if there are no metrics, otherwise we defer to
    # self.args.prediction_loss_only
    loss_only = True if self.compute_metrics is None else None
    output = self.evaluation_loop(
        eval_dataloader,
        description="Evaluation",
        prediction_loss_only=loss_only,
        ignore_keys=ignore_keys,
        metric_key_prefix=metric_key_prefix,
    )
    metrics = output.metrics

    total_batch_size = self.args.eval_batch_size * self.args.world_size
    preparation_key = f"{metric_key_prefix}_model_preparation_time"
    if preparation_key in metrics:
        # Shift the start so model preparation time is excluded from the speed metrics
        start_time += metrics[preparation_key]
    metrics.update(
        speed_metrics(
            metric_key_prefix,
            start_time,
            num_samples=output.num_samples,
            num_steps=math.ceil(output.num_samples / total_batch_size),
        )
    )

    self.log(metrics)

    if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
        xm.master_print(met.metrics_report())

    self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)

    self._memory_tracker.stop_and_update_metrics(metrics)

    return metrics
- def evaluation_loop(
- self,
- dataloader: DataLoader,
- description: str,
- prediction_loss_only: bool | None = None,
- ignore_keys: list[str] | None = None,
- metric_key_prefix: str = "eval",
- ) -> EvalLoopOutput:
- """
- Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
- Works both with or without labels.
- """
- args = self.args
- prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
- # if eval is called w/o train, handle model prep here
- if self.is_deepspeed_enabled and self.deepspeed is None:
- _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
- model = self._wrap_model(self.model, training=False)
- if len(self.accelerator._models) == 0 and model is self.model:
- start_time = time.time()
- model = (
- self.accelerator.prepare(model)
- if self.is_deepspeed_enabled or (self.is_fsdp_enabled and not self.args.torch_compile)
- else self.accelerator.prepare_model(model, evaluation_mode=True)
- )
- self.model_preparation_time = round(time.time() - start_time, 4)
- if self.is_fsdp_enabled:
- self.model = model
- # for the rest of this function `model` is the outside model, whether it was wrapped or not
- if model is not self.model:
- self.model_wrapped = model
- # backward compatibility
- if self.is_deepspeed_enabled:
- self.deepspeed = self.model_wrapped
- # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
- # while ``train`` is running, cast it to the right dtype first and then put on device
- if not self.is_in_train:
- if args.fp16_full_eval:
- model = model.to(dtype=torch.float16, device=args.device)
- elif args.bf16_full_eval:
- model = model.to(dtype=torch.bfloat16, device=args.device)
- batch_size = self.args.eval_batch_size
- logger.info(f"\n***** Running {description} *****")
- if has_length(dataloader):
- logger.info(f" Num examples = {self.num_examples(dataloader)}")
- else:
- logger.info(" Num examples: Unknown")
- logger.info(f" Batch size = {batch_size}")
- if hasattr(model, "eval") and callable(model.eval):
- model.eval()
- if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
- self.optimizer.eval()
- self.callback_handler.eval_dataloader = dataloader
- # Do this before wrapping.
- eval_dataset = getattr(dataloader, "dataset", None)
- # Initialize containers
- all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
- all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
- all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
- all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
- metrics = None
- eval_set_kwargs = {}
- # Will be useful when we have an iterable dataset so don't know its length.
- observed_num_examples = 0
- # Main evaluation loop
- for step, inputs in enumerate(dataloader):
- # Update the observed num examples
- observed_batch_size = find_batch_size(inputs)
- if observed_batch_size is not None:
- observed_num_examples += observed_batch_size
- # For batch samplers, batch_size is not known by the dataloader in advance.
- if batch_size is None:
- batch_size = observed_batch_size
- # Prediction step
- losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
- main_input_name = getattr(self.model, "main_input_name", "input_ids")
- inputs_decode = (
- self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
- )
- if is_torch_xla_available():
- xm.mark_step()
- # Update containers
- if losses is not None:
- losses = self.gather_function(losses.repeat(batch_size))
- all_losses.add(losses)
- if inputs_decode is not None:
- inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
- inputs_decode = self.gather_function(inputs_decode)
- if not self.args.batch_eval_metrics or description == "Prediction":
- all_inputs.add(inputs_decode)
- if labels is not None:
- # Pad labels here, preparing for preprocess_logits_for_metrics in next logits block.
- labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
- if logits is not None:
- logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
- if self.preprocess_logits_for_metrics is not None:
- logits = self.preprocess_logits_for_metrics(logits, labels)
- logits = self.gather_function(logits)
- if not self.args.batch_eval_metrics or description == "Prediction":
- all_preds.add(logits)
- if labels is not None:
- labels = self.gather_function(labels)
- if not self.args.batch_eval_metrics or description == "Prediction":
- all_labels.add(labels)
- self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
- if self.args.batch_eval_metrics:
- if self.compute_metrics is not None and logits is not None and labels is not None:
- is_last_step = self.accelerator.gradient_state.end_of_dataloader
- batch_kwargs = {}
- batch_kwargs["losses"] = losses if "loss" in args.include_for_metrics else None
- batch_kwargs["inputs"] = inputs if "inputs" in args.include_for_metrics else None
- metrics = self.compute_metrics(
- EvalPrediction(predictions=logits, label_ids=labels, **batch_kwargs),
- compute_result=is_last_step,
- )
- del losses, logits, labels, inputs
- torch.cuda.empty_cache()
- # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
- elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
- all_losses.to_cpu_and_numpy()
- all_preds.to_cpu_and_numpy()
- all_labels.to_cpu_and_numpy()
- all_inputs.to_cpu_and_numpy()
- del losses, logits, labels, inputs
- torch.cuda.empty_cache()
- # After all calls to `.gather_function`, reset to `gather_for_metrics`:
- self.gather_function = self.accelerator.gather_for_metrics
- # Gather all remaining tensors and put them back on the CPU
- all_losses = all_losses.get_arrays()
- all_preds = all_preds.get_arrays()
- all_labels = all_labels.get_arrays()
- all_inputs = all_inputs.get_arrays()
- # Number of samples
- if has_length(eval_dataset):
- num_samples = len(eval_dataset)
- # The instance check is weird and does not actually check for the type, but whether the dataset has the right
- # methods. Therefore we need to make sure it also has the attribute.
- elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
- num_samples = eval_dataset.num_examples
- else:
- if has_length(dataloader):
- num_samples = self.num_examples(dataloader)
- else: # both len(dataloader.dataset) and len(dataloader) fail
- num_samples = observed_num_examples
- if num_samples == 0 and observed_num_examples > 0:
- num_samples = observed_num_examples
- # Metrics!
- if (
- self.compute_metrics is not None
- and all_preds is not None
- and all_labels is not None
- and not self.args.batch_eval_metrics
- ):
- eval_set_kwargs["losses"] = all_losses if "loss" in args.include_for_metrics else None
- eval_set_kwargs["inputs"] = all_inputs if "inputs" in args.include_for_metrics else None
- metrics = self.compute_metrics(
- EvalPrediction(predictions=all_preds, label_ids=all_labels, **eval_set_kwargs)
- )
- elif metrics is None:
- metrics = {}
- # To be JSON-serializable, we need to remove numpy types or zero-d tensors
- metrics = denumpify_detensorize(metrics)
- if isinstance(all_losses, list) and all_losses:
- metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
- elif isinstance(all_losses, np.ndarray):
- metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
- if hasattr(self, "model_preparation_time"):
- metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time
- # Prefix all keys with metric_key_prefix + '_'
- for key in list(metrics.keys()):
- if not key.startswith(f"{metric_key_prefix}_"):
- metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
- return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
def predict(
    self, test_dataset: Dataset, ignore_keys: list[str] | None = None, metric_key_prefix: str = "test"
) -> PredictionOutput:
    """
    Run the prediction loop over `test_dataset` and return predictions plus any available metrics.

    When the test dataset carries labels, metrics are returned as well, like in `evaluate()`.

    Args:
        test_dataset (`Dataset`):
            Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
            `model.forward()` method are automatically removed. Has to implement the method `__len__`
        ignore_keys (`list[str]`, *optional*):
            A list of keys in the output of your model (if it is a dictionary) that should be ignored when
            gathering predictions.
        metric_key_prefix (`str`, *optional*, defaults to `"test"`):
            An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
            "test_bleu" if the prefix is "test" (default)

    <Tip>

    If your predictions or labels have different sequence length (for instance because you're doing dynamic padding
    in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
    one array. The padding index is -100.

    </Tip>

    Returns: *NamedTuple* A namedtuple with the following keys:

        - predictions (`np.ndarray`): The predictions on `test_dataset`.
        - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
        - metrics (`dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
          labels).
    """
    # Memory tracking must begin before any other work is done.
    self._memory_tracker.start()

    dataloader = self.get_test_dataloader(test_dataset)
    started_at = time.time()

    output = self.evaluation_loop(
        dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
    )
    metrics = output.metrics

    # When the loop reported a model preparation time, shift the start timestamp by it
    # (presumably so that preparation is not counted in the speed metrics).
    preparation_key = f"{metric_key_prefix}_model_preparation_time"
    if preparation_key in metrics:
        started_at += metrics[preparation_key]

    # Number of samples consumed per step across all processes.
    samples_per_step = self.args.eval_batch_size * self.args.world_size
    throughput = speed_metrics(
        metric_key_prefix,
        started_at,
        num_samples=output.num_samples,
        num_steps=math.ceil(output.num_samples / samples_per_step),
    )
    metrics.update(throughput)

    self.control = self.callback_handler.on_predict(self.args, self.state, self.control, metrics)
    self._memory_tracker.stop_and_update_metrics(metrics)

    return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=metrics)
def prediction_step(
    self,
    model: nn.Module,
    inputs: dict[str, torch.Tensor | Any],
    prediction_loss_only: bool,
    ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    """
    Perform an evaluation step on `model` using `inputs`.

    Subclass and override to inject custom behavior.

    Args:
        model (`nn.Module`):
            The model to evaluate.
        inputs (`dict[str, torch.Tensor | Any]`):
            The inputs and targets of the model.

            The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
            argument `labels`. Check your model's documentation for all accepted arguments.
        prediction_loss_only (`bool`):
            Whether or not to return the loss only.
        ignore_keys (`list[str]`, *optional*):
            A list of keys in the output of your model (if it is a dictionary) that should be ignored when
            gathering predictions. When omitted, `model.config.keys_to_ignore_at_inference` is used if the model
            has a config (defaulting to `["past_key_values"]`), otherwise nothing is ignored.

    Return:
        tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
        logits and labels (each being optional).
    """
    has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
    # For CLIP-like models capable of returning loss values.
    # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
    # is `True` in `model.forward`.
    return_loss = inputs.get("return_loss")
    if return_loss is None:
        return_loss = self.can_return_loss
    loss_without_labels = len(self.label_names) == 0 and return_loss
    # A loss is computed either from the labels or by a model that returns one on its own.
    needs_loss = has_labels or loss_without_labels

    inputs = self._prepare_inputs(inputs)
    if ignore_keys is None:
        if hasattr(self.model, "config"):
            ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", ["past_key_values"])
        else:
            ignore_keys = []
    # Build both exclusion lists once instead of concatenating a new list for every output key; copying
    # through `list(...)` also lets a tuple (or any iterable) of keys work.
    output_keys_to_skip = list(ignore_keys)
    loss_output_keys_to_skip = [*output_keys_to_skip, "loss"]

    # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
    if needs_loss:
        labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
        if len(labels) == 1:
            labels = labels[0]
    else:
        labels = None

    with torch.no_grad():
        if is_sagemaker_mp_enabled():
            raw_outputs = smp_forward_only(model, inputs)
            if needs_loss:
                if isinstance(raw_outputs, dict):
                    loss_mb = raw_outputs["loss"]
                    logits_mb = tuple(v for k, v in raw_outputs.items() if k not in loss_output_keys_to_skip)
                else:
                    # Non-dict outputs: the loss is the first element, the rest are the logits.
                    loss_mb = raw_outputs[0]
                    logits_mb = raw_outputs[1:]

                loss = loss_mb.reduce_mean().detach().cpu()
                logits = smp_nested_concat(logits_mb)
            else:
                loss = None
                if isinstance(raw_outputs, dict):
                    logits_mb = tuple(v for k, v in raw_outputs.items() if k not in output_keys_to_skip)
                else:
                    logits_mb = raw_outputs
                logits = smp_nested_concat(logits_mb)
        else:
            if needs_loss:
                with self.compute_loss_context_manager():
                    num_items_in_batch = self._get_num_items_in_batch([inputs], self.args.device)
                    loss, outputs = self.compute_loss(
                        model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
                    )
                loss = loss.detach().mean()

                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in loss_output_keys_to_skip)
                else:
                    # Non-dict outputs: drop the leading loss entry.
                    logits = outputs[1:]
            else:
                loss = None
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in output_keys_to_skip)
                else:
                    logits = outputs

    if prediction_loss_only:
        return (loss, None, None)

    logits = nested_detach(logits)
    # Unwrap a single-element container so callers get the tensor directly.
    if len(logits) == 1:
        logits = logits[0]

    return (loss, logits, labels)
def _evaluate(
    self,
    trial: "optuna.Trial | dict[str, Any] | None",
    ignore_keys_for_eval: list[str] | None,
    skip_scheduler: bool = False,
) -> dict[str, float]:
    """
    Run evaluation, report to HP search, and step ReduceLROnPlateau/GreedyLR if needed.

    Args:
        trial (`optuna.Trial` or `dict[str, Any]`, *optional*):
            The current hyperparameter-search trial, forwarded to `_report_to_hp_search`.
        ignore_keys_for_eval (`list[str]`, *optional*):
            Forwarded to `evaluate` as `ignore_keys`.
        skip_scheduler (`bool`, *optional*, defaults to `False`):
            When `True`, the metric-driven scheduler is not stepped.

    Returns:
        `dict[str, float]`: The metrics returned by `evaluate`.

    Raises:
        KeyError: If the metric the scheduler should monitor is missing from the evaluation metrics.
    """
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
    self._report_to_hp_search(trial, self.state.global_step, metrics)

    # Run delayed LR scheduler now that metrics are populated
    if (
        isinstance(self.lr_scheduler, (torch.optim.lr_scheduler.ReduceLROnPlateau, GreedyLR))
        and not skip_scheduler
    ):
        metric_to_check = self.args.metric_for_best_model
        if metric_to_check is None:
            # No metric configured (e.g. a scheduler instance supplied by the caller): monitor the
            # evaluation loss instead of failing on `None.startswith`.
            metric_to_check = "loss"
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        # Only the lookup is guarded, so a KeyError raised inside the scheduler itself is not
        # misreported as a missing metric.
        try:
            monitored_value = metrics[metric_to_check]
        except KeyError as exc:
            raise KeyError(
                f"The `metric_for_best_model` training argument is set to '{metric_to_check}', "
                f"which is not found in the evaluation metrics. "
                f"The available evaluation metrics are: {list(metrics.keys())}. "
                f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or "
                f"consider changing the `metric_for_best_model` via the TrainingArguments."
            ) from exc
        self.lr_scheduler.step(monitored_value)
    return metrics
- # ---- Checkpoint Saving ----
def _get_output_dir(self, trial: "optuna.Trial | dict[str, Any] | None") -> str:
    """
    Return the directory checkpoints should be written to.

    Outside of a hyperparameter search (no backend or no trial) this is `self.args.output_dir`. During a
    search it is a per-run sub-directory named by `self.hp_name(trial)` when `hp_name` is set, and
    `run-<run_id>` otherwise.
    """
    backend = self.hp_search_backend
    if backend is None or trial is None:
        return self.args.output_dir

    # Resolve the backend-specific run identifier.
    if backend == HPSearchBackend.OPTUNA:
        run_id = trial.number
    elif backend == HPSearchBackend.RAY:
        import ray.tune

        run_id = ray.tune.get_context().get_trial_id()
    elif backend == HPSearchBackend.WANDB:
        import wandb

        run_id = wandb.run.id

    if self.hp_name is not None:
        run_name = self.hp_name(trial)
    else:
        run_name = f"run-{run_id}"
    return os.path.join(self.args.output_dir, run_name)
def _save_checkpoint(self, model: nn.Module, trial: "optuna.Trial | dict[str, Any] | None") -> None:
    """
    Save model checkpoint, optimizer, scheduler, scaler, RNG states, and trainer state.

    The checkpoint is written to `<run_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>`, where `run_dir` comes
    from `_get_output_dir`. Afterwards the checkpoint is optionally pushed to the Hub and older checkpoints
    are rotated.

    Args:
        model (`nn.Module`):
            Not used directly by this method; the model is saved through `self.save_model`.
        trial (`optuna.Trial` or `dict[str, Any]`, *optional*):
            The current hyperparameter-search trial, used to resolve the run directory.
    """
    # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
    # want to save except FullyShardedDDP.
    # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

    # Save model checkpoint
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

    if self.hp_search_backend is None and trial is None:
        self.store_flos()

    run_dir = self._get_output_dir(trial=trial)
    output_dir = os.path.join(run_dir, checkpoint_folder)
    self.save_model(output_dir, _internal_call=True)

    if (
        self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH, SaveStrategy.BEST]
        and self.state.best_global_step
    ):
        # Wait for everyone to get here so we are sure the model has been saved by process 0
        # before we check if the best_checkpoint_dir exists
        if is_torch_xla_available():
            xm.rendezvous("load_best_model_at_end")
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            dist.barrier()
        elif is_sagemaker_mp_enabled():
            smp.barrier()
        best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}"
        best_checkpoint_dir = os.path.join(run_dir, best_checkpoint_folder)

        if os.path.exists(best_checkpoint_dir):
            self.state.best_model_checkpoint = best_checkpoint_dir

    if not self.args.save_only_model:
        # Save optimizer and scheduler
        self._save_optimizer_and_scheduler(output_dir)
        self._save_scaler(output_dir)
        # Save RNG state
        self._save_rng_state(output_dir)

    # Save the Trainer state
    if self.args.should_save:
        # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
        exportable = [
            cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
        ]
        for cb in exportable:
            cb_name = cb.__class__.__name__
            cb_state = cb.state()
            # `.get` rather than indexing: a callback with no entry yet is recorded instead of
            # aborting the save with a KeyError.
            recorded = self.state.stateful_callbacks.get(cb_name)
            if isinstance(recorded, list):
                recorded.append(cb_state)
            else:
                self.state.stateful_callbacks[cb_name] = cb_state
        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

    if self.args.push_to_hub:
        self._push_from_checkpoint(output_dir)

    # Maybe delete some older checkpoints.
    if self.args.should_save:
        # we use mtime as default, filesystems without mtime support will be detected in `sort_checkpoints`
        rotate_checkpoints(
            output_dir=run_dir,
            save_total_limit=self.args.save_total_limit,
            best_model_checkpoint=self.state.best_model_checkpoint,
            use_mtime=True,
        )
def _determine_best_metric(self, metrics: dict[str, float], trial: "optuna.Trial | dict[str, Any] | None") -> bool:
    """
    Check whether the latest evaluation improved on the best metric seen so far.

    When it did, `self.state.best_metric` is updated, and `self.state.best_global_step` too if the save
    strategy is one of STEPS, EPOCH or BEST.

    Returns:
        bool: True if a new best metric was found, else False
    """
    configured_metric = self.args.metric_for_best_model
    if configured_metric is None:
        return False

    # Evaluation metrics are looked up under their `eval_`-prefixed name.
    metric_key = configured_metric if configured_metric.startswith("eval_") else f"eval_{configured_metric}"

    try:
        current_value = metrics[metric_key]
    except KeyError as exc:
        raise KeyError(
            f"The `metric_for_best_model` training argument is set to '{metric_key}', which is not found in the evaluation metrics. "
            f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
        ) from exc

    is_improvement = np.greater if self.args.greater_is_better else np.less

    # Seed the best value with the worst possible one so the first comparison can succeed.
    if self.state.best_metric is None:
        self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")

    if not is_improvement(current_value, self.state.best_metric):
        return False

    self.state.best_metric = current_value
    if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH, SaveStrategy.BEST]:
        self.state.best_global_step = self.state.global_step
    return True
def _save_rng_state(self, output_dir: str) -> None:
    """
    Save random number generator states for reproducible resumption.

    The Python, NumPy and torch CPU states are always stored; accelerator states (CUDA, XLA, NPU, HPU, MLU,
    MUSA) are added for whichever backends report as available. The file is `rng_state.pth` when
    `world_size <= 1` and `rng_state_<process_index>.pth` otherwise.
    """

    def _accelerator_state(rng_module):
        # Distributed runs store the states of all devices, other runs a single state.
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            return rng_module.get_rng_state_all()
        return rng_module.get_rng_state()

    rng_states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cpu": torch.random.get_rng_state(),
    }

    if torch.cuda.is_available():
        rng_states["cuda"] = _accelerator_state(torch.cuda.random)

    if is_torch_xla_available():
        rng_states["xla"] = xm.get_rng_state()

    if is_torch_npu_available():
        rng_states["npu"] = _accelerator_state(torch.npu.random)

    if is_torch_hpu_available():
        rng_states["hpu"] = _accelerator_state(torch.hpu.random)

    if is_torch_mlu_available():
        rng_states["mlu"] = _accelerator_state(torch.mlu.random)

    if is_torch_musa_available():
        # MUSA exposes the RNG helpers on the device module itself rather than on `.random`.
        rng_states["musa"] = _accelerator_state(torch.musa)

    # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
    # not yet exist.
    os.makedirs(output_dir, exist_ok=True)

    if self.args.world_size <= 1:
        file_name = "rng_state.pth"
    else:
        file_name = f"rng_state_{self.args.process_index}.pth"
    torch.save(rng_states, os.path.join(output_dir, file_name))
def _save_optimizer_and_scheduler(self, output_dir: str) -> None:
    """
    Save optimizer and learning rate scheduler states to `output_dir`.

    The optimizer is written through the first backend that applies, checked in this order: torch-XLA,
    SageMaker model parallel, DeepSpeed, FSDP, and finally a plain `torch.save` when `self.args.should_save`
    is set. The scheduler is saved inside the XLA branch, or by the trailing `torch.save` block for the
    other backends.

    Args:
        output_dir (`str`):
            Checkpoint directory the state files are written into.
    """
    if is_torch_xla_available():
        xm.rendezvous("saving_optimizer_states")
        if self.is_fsdp_xla_v1_enabled:
            # One file per rank (`master_only=False`), holding the optimizer state plus shard metadata.
            optm = {
                "optimizer": self.optimizer.state_dict(),
                "shard_metadata": self.model.get_shard_metadata(),
            }
            xm.save(
                optm,
                os.path.join(
                    output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
                ),
                master_only=False,
            )
        else:
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
        with warnings.catch_warnings(record=True) as caught_warnings:
            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
    elif is_sagemaker_mp_enabled():
        opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
        smp.barrier()
        # Written only by rdp rank 0, or by every rank when the optimizer state is sharded.
        if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
            smp.save(
                opt_state_dict,
                os.path.join(output_dir, OPTIMIZER_NAME),
                partial=True,
                v3=smp.state.cfg.shard_optimizer_state,
            )
    elif self.is_deepspeed_enabled:
        # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
        # config `stage3_gather_16bit_weights_on_model_save` is True
        # Only pass `exclude_frozen_parameters` when this `save_checkpoint` signature accepts it.
        accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
            inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
        )
        if accept_exclude_frozen_parameters and _is_peft_model(self.model):
            self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
        else:
            self.model_wrapped.save_checkpoint(output_dir)
    elif self.is_fsdp_enabled:
        # save fsdp specific ckpt for resuming from ckpt
        save_fsdp_model(
            self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **get_fsdp_ckpt_kwargs()
        )
        save_fsdp_optimizer(
            self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
        )
    elif self.args.should_save:
        # Plain PyTorch path: no backend-specific saver applies, so write the optimizer state directly.
        torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

    # Save the scheduler for the non-XLA backends. Under DeepSpeed this only happens when the scheduler is
    # not a `DeepSpeedSchedulerWrapper`.
    is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
        self.lr_scheduler, DeepSpeedSchedulerWrapper
    )
    if (
        self.args.should_save
        and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
        and not is_torch_xla_available()
    ):
        with warnings.catch_warnings(record=True) as caught_warnings:
            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
        reissue_pt_warnings(caught_warnings)
def _save_scaler(self, output_dir: str) -> None:
    """
    Save the gradient scaler state if one exists.

    Nothing is written when the accelerator exposes no `scaler` attribute or the scaler is `None`.
    """
    # A missing attribute and a `None` scaler are treated the same way: there is nothing to save.
    scaler = getattr(getattr(self, "accelerator", None), "scaler", None)
    if scaler is None:
        return

    scaler_path = os.path.join(output_dir, SCALER_NAME)
    if is_torch_xla_available():
        xm.rendezvous("saving_scaler_state")
        with warnings.catch_warnings(record=True) as caught_warnings:
            xm.save(scaler.state_dict(), scaler_path)
            reissue_pt_warnings(caught_warnings)
    elif self.args.should_save:
        with warnings.catch_warnings(record=True) as caught_warnings:
            torch.save(scaler.state_dict(), scaler_path)
        reissue_pt_warnings(caught_warnings)
- # ---- Checkpoint Resuming ----
def _load_from_checkpoint(self, resume_from_checkpoint: str, model: nn.Module | None = None) -> None:
    """
    Load model weights from a checkpoint directory.

    Handles single-file weights (safetensors or torch), FSDP checkpoints, PEFT adapters (including several
    adapters stored in sub-directories) and sharded checkpoints.

    Args:
        resume_from_checkpoint (`str`):
            Path of the checkpoint to load from.
        model (`nn.Module`, *optional*):
            Model to load the weights into. Defaults to `self.model`.

    Raises:
        ValueError: If the checkpoint is an FSDP checkpoint while FSDP is not enabled, or if no weights file,
            weights index, adapter file, FSDP checkpoint or adapter sub-directory is found.
    """
    if model is None:
        model = self.model

    # Candidate file locations inside the checkpoint.
    config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
    adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
    adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
    weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
    weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
    safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
    safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
    is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
        # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
        any(
            FSDP_MODEL_NAME in folder_name
            for folder_name in os.listdir(resume_from_checkpoint)
            if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
        )
        # this checks the FSDP state dict when `FULL_STATE_DICT` is used
        or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
    )
    # if multiple adapters exist, they get saved in sub directories
    adapter_subdirs = (
        [
            folder_name
            for folder_name in os.listdir(resume_from_checkpoint)
            if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
            and (
                os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME))
                or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME))
            )
        ]
        if os.path.isdir(resume_from_checkpoint)
        else []
    )

    if is_fsdp_ckpt and not self.is_fsdp_enabled:
        raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")

    # Fail early when nothing loadable is present.
    if not (
        any(
            os.path.isfile(f)
            for f in [
                weights_file,
                safe_weights_file,
                weights_index_file,
                safe_weights_index_file,
                adapter_weights_file,
                adapter_safe_weights_file,
            ]
        )
        or is_fsdp_ckpt
        or adapter_subdirs
    ):
        raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

    logger.info(f"Loading model from {resume_from_checkpoint}.")

    # Warn when the checkpoint was produced by a different library version.
    if os.path.isfile(config_file):
        config = PreTrainedConfig.from_json_file(config_file)
        checkpoint_version = config.transformers_version
        if checkpoint_version is not None and checkpoint_version != __version__:
            logger.warning(
                f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                f"Transformers but your current version is {__version__}. This is not recommended and could "
                "yield to errors or unwanted behaviors."
            )

    if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
        # If the model is on the GPU, it still works!
        if is_sagemaker_mp_enabled():
            smp.resume_from_checkpoint(
                path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
            )
        elif self.is_fsdp_enabled:
            load_fsdp_model(
                self.accelerator.state.fsdp_plugin,
                self.accelerator,
                model,
                resume_from_checkpoint,
                **get_fsdp_ckpt_kwargs(),
            )
        else:
            # We load the model state dict on the CPU to avoid an OOM error.
            if os.path.isfile(safe_weights_file):
                state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
            else:
                check_torch_load_is_safe()
                state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)

            # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
            # which takes *args instead of **kwargs
            load_result = model.load_state_dict(state_dict, False)
            # release memory
            del state_dict
            self._issue_warnings_after_load(load_result)

    # Load adapters following PR # 24096
    elif _is_peft_model(model):
        # If training a model using PEFT, assume that adapter have been saved properly.
        if hasattr(model, "active_adapters") and hasattr(model, "load_adapter"):
            if os.path.exists(resume_from_checkpoint):
                active_adapters = model.active_adapters

                if len(active_adapters) > 1:
                    logger.warning("Multiple active adapters detected will only consider the first adapter")

                active_adapter = active_adapters[0]

                if adapter_subdirs:
                    # Load every saved adapter, but only the active one is marked trainable.
                    for subdir_name in adapter_subdirs:
                        peft_id = os.path.join(resume_from_checkpoint, subdir_name)
                        model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter))
                    model.set_adapter(active_adapter)
                else:
                    model.load_adapter(resume_from_checkpoint, active_adapter, is_trainable=True)
            else:
                logger.warning(
                    "The intermediate checkpoints of PEFT may not be saved correctly, "
                    f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
                    "Check some examples here: https://github.com/huggingface/peft/issues/96"
                )
        else:
            logger.warning(f"Could not load adapter model, make sure to have PEFT >= {MIN_PEFT_VERSION} installed")
    else:
        # We load the sharded checkpoint
        load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled())
        if not is_sagemaker_mp_enabled():
            self._issue_warnings_after_load(load_result)
def _load_best_model(self) -> None:
    """
    Load the best model found during training based on the tracked metric.

    Weights are read from `self.state.best_model_checkpoint` through DeepSpeed, FSDP, SageMaker model
    parallel, a PEFT adapter, a single weights file, or a sharded checkpoint, whichever applies first.
    When none of the expected files exist, a warning is logged and nothing is loaded.

    Raises:
        RuntimeError: If `load_adapter` fails; for prompt-learning PEFT methods the error is re-raised with
            a message recommending to load the model manually.
    """
    logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
    best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
    best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
    best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
    best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)

    # SageMaker model parallel loads into the wrapped model; every other path uses `self.model`.
    model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
    if self.is_deepspeed_enabled:
        deepspeed_load_checkpoint(
            self.model_wrapped,
            self.state.best_model_checkpoint,
            load_module_strict=not _is_peft_model(self.model),
        )
    elif self.is_fsdp_enabled:
        load_result = load_fsdp_model(
            self.accelerator.state.fsdp_plugin,
            self.accelerator,
            model,
            self.state.best_model_checkpoint,
            **get_fsdp_ckpt_kwargs(),
        )
    elif (
        os.path.exists(best_model_path)
        or os.path.exists(best_safe_model_path)
        or os.path.exists(best_adapter_model_path)
        or os.path.exists(best_safe_adapter_model_path)
    ):
        # Reset to False on the PEFT paths that end up not loading anything, so that the load warnings
        # below are skipped.
        has_been_loaded = True
        if is_sagemaker_mp_enabled():
            smp.resume_from_checkpoint(
                path=self.state.best_model_checkpoint,
                tag=WEIGHTS_NAME,
                partial=False,
                load_optimizer=False,
            )
        else:
            if _is_peft_model(model):
                # If training a model using PEFT, assume that adapter have been saved properly.
                if hasattr(model, "active_adapters") and hasattr(model, "load_adapter"):
                    active_adapter = model.active_adapters[0]
                    if len(model.active_adapters) > 1:
                        logger.warning("Detected multiple active adapters, will only consider the first one")
                    if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
                        try:
                            model.load_adapter(self.state.best_model_checkpoint, active_adapter)
                        except RuntimeError as exc:
                            if model.peft_config[active_adapter].is_prompt_learning:
                                # for context: https://github.com/huggingface/peft/issues/2256
                                msg = (
                                    "When using prompt learning PEFT methods such as "
                                    f"{model.peft_config[active_adapter].peft_type.value}, setting "
                                    "load_best_model_at_end=True can lead to errors, it is recommended "
                                    "to set this to False and to load the model manually from the checkpoint "
                                    "directory using PeftModel.from_pretrained(base_model, <path>) after training "
                                    "has finished."
                                )
                                raise RuntimeError(msg) from exc
                            else:
                                raise
                        # Load_adapter has no return value present, modify it when appropriate.
                        from torch.nn.modules.module import _IncompatibleKeys

                        # Empty placeholder so the warning step below has a `load_result` to inspect.
                        load_result = _IncompatibleKeys([], [])
                    else:
                        logger.warning(
                            "The intermediate checkpoints of PEFT may not be saved correctly, "
                            f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
                            "Check some examples here: https://github.com/huggingface/peft/issues/96"
                        )
                        has_been_loaded = False
                else:
                    logger.warning(
                        f"Could not load adapter model, make sure to have PEFT >= {MIN_PEFT_VERSION} installed"
                    )
                    has_been_loaded = False
            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                if os.path.isfile(best_safe_model_path):
                    state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                else:
                    check_torch_load_is_safe()
                    state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)

                # If the model is on the GPU, it still works!
                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                # which takes *args instead of **kwargs
                load_result = model.load_state_dict(state_dict, False)
            if not is_sagemaker_mp_enabled() and has_been_loaded:
                self._issue_warnings_after_load(load_result)
    elif os.path.exists(os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)) or os.path.exists(
        os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)
    ):
        # A weights index file is present: load the sharded checkpoint.
        load_result = load_sharded_checkpoint(
            model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
        )
        if not is_sagemaker_mp_enabled():
            self._issue_warnings_after_load(load_result)
    else:
        logger.warning(
            f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
            "on multiple nodes, you should activate `--save_on_each_node`."
        )
- def _load_rng_state(self, checkpoint: str | None) -> None:
- """Restore random number generator states from a checkpoint."""
- # Load RNG states from `checkpoint`
- if checkpoint is None:
- return
- if self.args.world_size > 1:
- process_index = self.args.process_index
- rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
- if not os.path.isfile(rng_file):
- logger.info(
- f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
- "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
- )
- return
- else:
- rng_file = os.path.join(checkpoint, "rng_state.pth")
- if not os.path.isfile(rng_file):
- logger.info(
- "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
- "fashion, reproducibility is not guaranteed."
- )
- return
- with safe_globals():
- check_torch_load_is_safe()
- checkpoint_rng_state = torch.load(rng_file, weights_only=True)
- random.setstate(checkpoint_rng_state["python"])
- np.random.set_state(checkpoint_rng_state["numpy"])
- torch.random.set_rng_state(checkpoint_rng_state["cpu"])
- if is_torch_xla_available():
- xm.set_rng_state(checkpoint_rng_state["xla"])
- is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
- if torch.cuda.is_available():
- set_rng_state_for_device("CUDA", torch.cuda, checkpoint_rng_state, is_distributed)
- if is_torch_npu_available():
- set_rng_state_for_device("NPU", torch.npu, checkpoint_rng_state, is_distributed)
- if is_torch_hpu_available():
- set_rng_state_for_device("HPU", torch.hpu, checkpoint_rng_state, is_distributed)
- if is_torch_mlu_available():
- set_rng_state_for_device("MLU", torch.mlu, checkpoint_rng_state, is_distributed)
- if is_torch_musa_available():
- set_rng_state_for_device("MUSA", torch.musa, checkpoint_rng_state, is_distributed)
- def _load_optimizer_and_scheduler(self, checkpoint: str | None) -> None:
- """If optimizer and scheduler states exist, load them."""
- if checkpoint is None:
- return
- if self.is_deepspeed_enabled:
- # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
- if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
- with warnings.catch_warnings(record=True) as caught_warnings:
- check_torch_load_is_safe()
- self.lr_scheduler.load_state_dict(
- torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
- )
- reissue_pt_warnings(caught_warnings)
- return
- checkpoint_file_exists = (
- glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
- if is_sagemaker_mp_enabled()
- else (
- os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
- or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN))
- or (
- os.path.isdir(checkpoint)
- and any(
- OPTIMIZER_NAME_BIN.split(".")[0] in folder_name
- for folder_name in os.listdir(checkpoint)
- if os.path.isdir(os.path.join(checkpoint, folder_name))
- )
- )
- )
- )
- checkpoint_file_exists = (
- glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
- if self.is_fsdp_xla_v1_enabled
- else checkpoint_file_exists
- )
- if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
- # Load in optimizer and scheduler states
- if is_torch_xla_available():
- # On TPU we have to take some extra precautions to properly load the states on the right device.
- if self.is_fsdp_xla_v1_enabled:
- check_torch_load_is_safe()
- optimizer_state = torch.load(
- os.path.join(
- checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
- ),
- map_location="cpu",
- weights_only=True,
- )
- # We only need `optimizer` when resuming from checkpoint
- optimizer_state = optimizer_state["optimizer"]
- else:
- check_torch_load_is_safe()
- optimizer_state = torch.load(
- os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True
- )
- with warnings.catch_warnings(record=True) as caught_warnings:
- check_torch_load_is_safe()
- lr_scheduler_state = torch.load(
- os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True
- )
- reissue_pt_warnings(caught_warnings)
- xm.send_cpu_data_to_device(optimizer_state, self.args.device)
- xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
- self.optimizer.load_state_dict(optimizer_state)
- self.lr_scheduler.load_state_dict(lr_scheduler_state)
- else:
- if is_sagemaker_mp_enabled():
- def opt_load_hook(mod, opt):
- opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
- self.model_wrapped.register_post_step_hook(opt_load_hook)
- else:
- # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
- # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
- # likely to get OOM on CPU (since we load num_gpu times the optimizer state
- map_location = self.args.device if self.args.world_size > 1 else "cpu"
- if self.is_fsdp_enabled:
- load_fsdp_optimizer(
- self.accelerator.state.fsdp_plugin,
- self.accelerator,
- self.optimizer,
- self.model,
- checkpoint,
- **get_fsdp_ckpt_kwargs(),
- )
- else:
- check_torch_load_is_safe()
- self.optimizer.load_state_dict(
- torch.load(
- os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True
- )
- )
- with warnings.catch_warnings(record=True) as caught_warnings:
- check_torch_load_is_safe()
- self.lr_scheduler.load_state_dict(
- torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
- )
- reissue_pt_warnings(caught_warnings)
- def _load_scaler(self, checkpoint: str | None) -> None:
- """If scaler state exists, load it."""
- if checkpoint is None:
- return
- checkpoint_file_exists = os.path.isfile(os.path.join(checkpoint, SCALER_NAME))
- if checkpoint_file_exists:
- # On TPU we have to take some extra precautions to properly load the states on the right device.
- # Load in scaler states
- if is_torch_xla_available():
- with warnings.catch_warnings(record=True) as caught_warnings:
- check_torch_load_is_safe()
- scaler_state = torch.load(
- os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True
- )
- reissue_pt_warnings(caught_warnings)
- xm.send_cpu_data_to_device(scaler_state, self.args.device)
- self.accelerator.scaler.load_state_dict(scaler_state)
- else:
- with warnings.catch_warnings(record=True) as caught_warnings:
- check_torch_load_is_safe()
- self.accelerator.scaler.load_state_dict(
- torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True)
- )
- reissue_pt_warnings(caught_warnings)
- def _load_callback_state(self) -> None:
- """If callback states exist and were passed in, restore their states if enabled"""
- if not self.args.restore_callback_states_from_checkpoint:
- return
- # Callback states are stored in stateful_callbacks
- not_found = []
- new_callbacks = []
- original_callbacks = self.callback_handler.callbacks + [self.control]
- for stored_callback, data in self.state.stateful_callbacks.items():
- if not isinstance(data, list):
- data = [data]
- if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
- # We can load/restore from multiple callbacks of the same type.
- duplicates = [
- callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
- ]
- for callback, callback_data in zip(duplicates, data):
- args = callback_data.get("args", {})
- attributes = callback_data.get("attributes", {})
- new_callback = type(callback)(**args)
- for attribute, value in attributes.items():
- setattr(new_callback, attribute, value)
- if isinstance(callback, TrainerControl):
- # Specifically for restoring the `control` state
- self.control = new_callback
- else:
- new_callbacks.append(new_callback)
- # We remove the existing callback and add it to the list of new callbacks
- self.callback_handler.remove_callback(type(new_callback))
- logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
- else:
- not_found.append(stored_callback)
- if len(not_found) > 0:
- logger.warning(
- f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
- )
- for callback in new_callbacks:
- self.callback_handler.add_callback(callback)
- def _issue_warnings_after_load(self, load_result: Any) -> None:
- """Log warnings for missing or unexpected keys after loading a checkpoint."""
- if len(load_result.missing_keys) != 0:
- if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
- self.model._keys_to_ignore_on_save
- ):
- self.model.tie_weights()
- else:
- logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
- if len(load_result.unexpected_keys) != 0:
- logger.warning(
- f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
- )
- # ---- Saving & Serialization ----
def save_model(self, output_dir: str | None = None, _internal_call: bool = False) -> None:
    """
    Save the model so that it can be reloaded with `from_pretrained()`.

    Files are written only where `args.should_save` is set, but some backends need every process to take
    part in gathering the state dict.

    Args:
        output_dir (`str`, *optional*):
            Destination folder. Defaults to `args.output_dir`.
        _internal_call (`bool`, *optional*, defaults to `False`):
            Set by the `Trainer` itself to skip the push to the Hub that a user-initiated save triggers.
    """
    target_dir = self.args.output_dir if output_dir is None else output_dir

    if is_torch_xla_available():
        save_tpu_checkpoint(
            self.model, self.args, self.accelerator, self.processing_class, self.is_fsdp_xla_v1_enabled, target_dir
        )
    elif is_sagemaker_mp_enabled():
        # Calling the state_dict needs to be done on the wrapped model and on all processes.
        os.makedirs(target_dir, exist_ok=True)
        wrapped_state_dict = self.model_wrapped.state_dict()
        if self.args.should_save:
            self._save(target_dir, state_dict=wrapped_state_dict)
        Path(os.path.join(target_dir, "user_content.pt")).touch()
    elif self.is_fsdp_enabled:
        if "FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type):
            full_state_dict = self.accelerator.get_state_dict(self.model)
            if self.args.should_save:
                self._save(target_dir, state_dict=full_state_dict)
    elif self.is_deepspeed_enabled:
        try:
            save_checkpoint_params = inspect.signature(self.model_wrapped.save_checkpoint).parameters
            accept_exclude_frozen_parameters = "exclude_frozen_parameters" in save_checkpoint_params
            zero3_sharding = self.deepspeed.config.get("zero_optimization", {}).get("stage", None) == 3
            if accept_exclude_frozen_parameters and _is_peft_model(self.model) and zero3_sharding:
                # When using PEFT with DeepSpeed ZeRO Stage 3, the frozen parameters are left out of the
                # consolidated state dict.
                gathered_state_dict = self.deepspeed._zero3_consolidated_16bit_state_dict(
                    exclude_frozen_parameters=True
                )
            else:
                gathered_state_dict = self.accelerator.get_state_dict(self.deepspeed)
            if self.args.should_save:
                self._save(target_dir, state_dict=gathered_state_dict)
        except ValueError:
            logger.warning(
                " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                " zero_to_fp32.py to recover weights"
            )
            if self.args.should_save:
                self._save(target_dir, state_dict={})
            # remove the dummy state_dict
            remove_dummy_checkpoint(self.args.should_save, target_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
            self.model_wrapped.save_checkpoint(target_dir)
    elif self.args.should_save:
        self._save(target_dir)

    # Push to the Hub when `save_model` is called by the user.
    if self.args.push_to_hub and not _internal_call:
        self.push_to_hub(commit_message="Model save", revision=self.args.hub_revision)
def _save(self, output_dir: str | None = None, state_dict: dict | None = None) -> None:
    """
    Write the model weights, the processing class (or the collator's tokenizer) and the training arguments.

    Args:
        output_dir (`str`, *optional*):
            Destination folder. Defaults to `args.output_dir`.
        state_dict (`dict`, *optional*):
            State dict to save. When `None` and the model is not directly a supported class, the model's own
            `state_dict()` is used.
    """
    # If we are executing this function, we are the process zero, so we don't check for that.
    if output_dir is None:
        output_dir = self.args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving model checkpoint to {output_dir}")

    supported_classes = (PreTrainedModel, PeftModel) if is_peft_available() else (PreTrainedModel,)

    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    if isinstance(self.model, supported_classes):
        self.model.save_pretrained(output_dir, state_dict=state_dict)
    else:
        if state_dict is None:
            state_dict = self.model.state_dict()

        unwrapped_model = self.accelerator.unwrap_model(self.model, keep_torch_compile=False)
        if isinstance(unwrapped_model, supported_classes):
            unwrapped_model.save_pretrained(output_dir, state_dict=state_dict)
        else:
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            weights_path = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
            safetensors.torch.save_file(state_dict, weights_path, metadata={"format": "pt"})

    if self.processing_class is not None:
        self.processing_class.save_pretrained(output_dir)
    elif getattr(self.data_collator, "tokenizer", None) is not None:
        logger.info("Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`")
        self.data_collator.tokenizer.save_pretrained(output_dir)

    # Good practice: save your training arguments together with the trained model
    torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
- # ---- Logging & Metrics ----
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
    """
    Record `logs` in the trainer state history and forward them to the callbacks.

    Subclass and override this method to inject custom behavior.

    Args:
        logs (`dict[str, float]`):
            The values to log. The dictionary is updated in place with the epoch and, when enabled, the
            token counters.
        start_time (`Optional[float]`):
            The start of training.
    """
    state = self.state
    if state.epoch is not None:
        logs["epoch"] = state.epoch

    if self.args.include_num_input_tokens_seen != "no":
        logs["num_input_tokens_seen"] = state.num_input_tokens_seen
        if start_time is not None:
            # Only count the tokens seen since the counter's initial value for this session.
            session_tokens = state.num_input_tokens_seen - self._initial_num_input_tokens_seen
            logs.update(speed_metrics("train", start_time, num_tokens=session_tokens))

    # The history entry is a copy, so the `step` key is not added to the caller's dictionary.
    history_entry = dict(logs)
    history_entry["step"] = state.global_step
    state.log_history.append(history_entry)
    self.control = self.callback_handler.on_log(self.args, state, self.control, logs)
def store_flos(self) -> None:
    """Add the pending `current_flos` to `state.total_flos` and reset the pending counter to zero."""
    pending_flos = self.current_flos
    if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
        # In distributed mode the per-process counters are summed before being accumulated.
        pending_flos = distributed_broadcast_scalars([pending_flos], device=self.args.device).sum().item()
    self.state.total_flos += pending_flos
    self.current_flos = 0
def floating_point_ops(self, inputs: dict[str, torch.Tensor | Any]) -> int:
    """
    Estimate the number of floating-point operations for one forward + backward pass over `inputs`.

    The estimate is `6 * number of elements in the main input * number of non-embedding parameters`. If the
    model has no `num_parameters` method or its main input is absent from `inputs`, 0 is returned; in that case
    either implement such a method in the model or subclass and override this method.

    Args:
        inputs (`dict[str, torch.Tensor | Any]`):
            The inputs and targets of the model.

    Returns:
        `int`: The number of floating-point operations.
    """
    model = self.model
    main_input = getattr(model, "main_input_name", "input_ids")
    if main_input not in inputs or not hasattr(model, "num_parameters"):
        return 0
    element_count = inputs[main_input].numel()
    return 6 * element_count * model.num_parameters(exclude_embeddings=True)
- # ---- Hub Integration ----
def init_hf_repo(self, token: str | None = None) -> None:
    """
    Create the Hub repository (if needed) and remember its id in `self.hub_model_id`.

    The repository name is `args.hub_model_id`, or the name of `args.output_dir` when that is unset. Only the
    world process zero does anything.

    Args:
        token (`str`, *optional*):
            Token used to create the repository. Defaults to `args.hub_token`.
    """
    # Only on process zero
    if not self.is_world_process_zero():
        return

    repo_name = self.args.hub_model_id
    if repo_name is None:
        repo_name = Path(self.args.output_dir).absolute().name

    if token is None:
        token = self.args.hub_token

    repo_url = create_repo(repo_name, token=token, private=self.args.hub_private_repo, exist_ok=True)
    self.hub_model_id = repo_url.repo_id
    self.push_in_progress = None
def create_model_card(
    self,
    language: str | None = None,
    license: str | None = None,
    tags: str | list[str] | None = None,
    model_name: str | None = None,
    finetuned_from: str | None = None,
    tasks: str | list[str] | None = None,
    dataset_tags: str | list[str] | None = None,
    dataset: str | list[str] | None = None,
    dataset_args: str | list[str] | None = None,
) -> None:
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` in `args.output_dir`, on the world process zero only. When a README
    already exists there and `tags` is given, the README's tags are appended to `tags` (without duplicates).

    Args:
        language (`str`, *optional*):
            The language of the model (if applicable)
        license (`str`, *optional*):
            The license of the model. Will default to the license of the pretrained model used, if the original
            model given to the `Trainer` comes from a repo on the Hub.
        tags (`str` or `list[str]`, *optional*):
            Some tags to be included in the metadata of the model card. A list passed here is not modified.
        model_name (`str`, *optional*):
            The name of the model.
        finetuned_from (`str`, *optional*):
            The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
            of the original model given to the `Trainer` (if it comes from the Hub).
        tasks (`str` or `list[str]`, *optional*):
            One or several task identifiers, to be included in the metadata of the model card.
        dataset_tags (`str` or `list[str]`, *optional*):
            One or several dataset tags, to be included in the metadata of the model card.
        dataset (`str` or `list[str]`, *optional*):
            One or several dataset identifiers, to be included in the metadata of the model card.
        dataset_args (`str` or `list[str]`, *optional*):
            One or several dataset arguments, to be included in the metadata of the model card.
    """
    if not self.is_world_process_zero():
        return

    model_card_filepath = os.path.join(self.args.output_dir, "README.md")
    is_peft_library = False
    if os.path.exists(model_card_filepath):
        # Parse the existing card once and read both the library name and the tags from it.
        existing_card_data = ModelCard.load(model_card_filepath).data
        is_peft_library = existing_card_data.get("library_name") == "peft"

        # Append existing tags in `tags`
        existing_tags = existing_card_data.tags
        if tags is not None and existing_tags is not None:
            # Work on a copy so the caller's list is left untouched.
            tags = [tags] if isinstance(tags, str) else list(tags)
            for tag in existing_tags:
                if tag not in tags:
                    tags.append(tag)

    training_summary = TrainingSummary.from_trainer(
        self,
        language=language,
        license=license,
        tags=tags,
        model_name=model_name,
        finetuned_from=finetuned_from,
        tasks=tasks,
        dataset_tags=dataset_tags,
        dataset=dataset,
        dataset_args=dataset_args,
    )
    model_card = training_summary.to_model_card()
    # Write as UTF-8 explicitly so non-ASCII content does not depend on the platform's default encoding.
    with open(model_card_filepath, "w", encoding="utf-8") as f:
        f.write(model_card)

    if is_peft_library:
        self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
def push_to_hub(
    self,
    commit_message: str | None = "End of training",
    blocking: bool = True,
    token: str | None = None,
    revision: str | None = None,
    **kwargs,
) -> CommitInfo:
    """
    Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.

    Parameters:
        commit_message (`str`, *optional*, defaults to `"End of training"`):
            Message to commit while pushing.
        blocking (`bool`, *optional*, defaults to `True`):
            Whether the function should return only when the upload has finished.
        token (`str`, *optional*, defaults to `None`):
            Token with write permission to overwrite Trainer's original args.
        revision (`str`, *optional*):
            The git revision to commit from. Defaults to `args.hub_revision`.
        kwargs (`dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to [`~Trainer.create_model_card`]. A `tags` list passed
            here is not modified.

    Returns:
        The result of `upload_folder`: the commit information if `blocking=True`, or a `Future` object tracking
        the progress of the commit if `blocking=False`. Processes other than the world process zero return
        `None`.
    """
    self.callback_handler.on_push_begin(self.args, self.state, self.control)

    model_name = kwargs.pop("model_name", None)
    if model_name is None and self.args.should_save:
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]

    token = token if token is not None else self.args.hub_token

    # In case the user calls this method with args.push_to_hub = False
    if self.hub_model_id is None:
        self.init_hf_repo(token=token)

    # Needs to be executed on all processes for TPU training, but will only save on the processed determined by
    # self.args.should_save.
    self.save_model(_internal_call=True)

    # Only push from one node.
    if not self.is_world_process_zero():
        return

    # Add additional tags in the case the model has already some tags and users pass
    # "tags" argument to `push_to_hub` so that trainer automatically handles internal tags
    # from all models since Trainer does not call `model.push_to_hub`.
    model_tags = getattr(self.model, "model_tags", None)
    if model_tags is not None:
        user_tags = kwargs.get("tags")
        if user_tags is None:
            # Covers both a missing `tags` and an explicit `tags=None`.
            merged_tags = []
        elif isinstance(user_tags, str):
            # If it is a string, convert it to a list
            merged_tags = [user_tags]
        else:
            # Copy so the list owned by the caller is not extended in place.
            merged_tags = list(user_tags)
        for model_tag in model_tags:
            if model_tag not in merged_tags:
                merged_tags.append(model_tag)
        kwargs["tags"] = merged_tags

    self.create_model_card(model_name=model_name, **kwargs)

    if revision is None:
        revision = self.args.hub_revision

    # Wait for the current upload to be finished.
    self._finish_current_push()
    return upload_folder(
        repo_id=self.hub_model_id,
        folder_path=self.args.output_dir,
        commit_message=commit_message,
        token=token,
        run_as_future=not blocking,
        ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
        revision=revision,
    )
- def _push_from_checkpoint(self, checkpoint_folder: str) -> None:
- """Push model and checkpoint files to the Hub from a checkpoint folder."""
- if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
- return
- # If we haven't finished the last push, we don't do this one unless args.hub_always_push=True.
- if not self.args.hub_always_push and self.push_in_progress is not None and not self.push_in_progress.is_done():
- return
- self.callback_handler.on_push_begin(self.args, self.state, self.control)
- output_dir = self.args.output_dir
- # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
- modeling_files = [CONFIG_NAME, GENERATION_CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
- # Add sharded checkpoints if we have an index
- for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
- index_path = os.path.join(checkpoint_folder, index_file)
- if os.path.isfile(index_path):
- modeling_files.append(index_file)
- with open(index_path) as f:
- index = json.loads(f.read())
- shard_files = list(set(index["weight_map"].values()))
- modeling_files.extend(shard_files)
- if is_peft_available():
- modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
- for modeling_file in modeling_files:
- if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
- shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
- # Saving the processing class is fast and we don't know how many files it may have spawned, so we resave it to be sure.
- if self.processing_class is not None:
- self.processing_class.save_pretrained(output_dir)
- # Same for the training arguments
- torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
- if self.args.save_strategy == SaveStrategy.STEPS:
- commit_message = f"Training in progress, step {self.state.global_step}"
- else:
- commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
- model_push_job = upload_folder(
- repo_id=self.hub_model_id,
- folder_path=output_dir,
- commit_message=commit_message,
- token=self.args.hub_token,
- run_as_future=True,
- ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
- revision=self.args.hub_revision,
- )
- push_jobs = [model_push_job]
- if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]:
- path_in_repo = (
- "last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name
- )
- checkpoint_push = upload_folder(
- repo_id=self.hub_model_id,
- folder_path=checkpoint_folder,
- path_in_repo=path_in_repo,
- commit_message=commit_message + ", checkpoint",
- token=self.args.hub_token,
- run_as_future=True,
- revision=self.args.hub_revision,
- )
- push_jobs.append(checkpoint_push)
- if self.push_in_progress is None or self.push_in_progress.is_done():
- self.push_in_progress = PushInProgress(push_jobs)
- else:
- self.push_in_progress.jobs.extend(push_jobs)
- def _finish_current_push(self) -> None:
- """Wait for any in-progress push to the Hub to complete."""
- if not hasattr(self, "push_in_progress"):
- return
- if self.push_in_progress is not None and not self.push_in_progress.is_done():
- logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.")
- self.push_in_progress.wait_until_done()
- # ---- Hyperparameter Search ----
def hyperparameter_search(
    self,
    hp_space: Callable[["optuna.Trial"], dict[str, float]] | None = None,
    compute_objective: Callable[[dict[str, float]], float] | None = None,
    n_trials: int = 20,
    direction: str | list[str] = "minimize",
    backend: str | HPSearchBackend | None = None,
    hp_name: Callable[["optuna.Trial"], str] | None = None,
    **kwargs,
) -> BestRun | list[BestRun]:
    """
    Launch a hyperparameter search using `optuna` or `Ray Tune`. The optimized quantity is determined
    by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
    the sum of all metrics otherwise.

    <Tip warning={true}>

    To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
    reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
    subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
    optimizer/scheduler.

    </Tip>

    Args:
        hp_space (`Callable[["optuna.Trial"], dict[str, float]]`, *optional*):
            A function that defines the hyperparameter search space. Will default to
            [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`]
            depending on your backend.
        compute_objective (`Callable[[dict[str, float]], float]`, *optional*):
            A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
            method. Will default to [`~trainer_utils.default_compute_objective`].
        n_trials (`int`, *optional*, defaults to 20):
            The number of trial runs to test.
        direction (`str` or `list[str]`, *optional*, defaults to `"minimize"`):
            If it's single objective optimization, direction is `str`, can be `"minimize"` or `"maximize"`, you
            should pick `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or
            several metrics. If it's multi objectives optimization, direction is `list[str]`, can be List of
            `"minimize"` and `"maximize"`, you should pick `"minimize"` when optimizing the validation loss,
            `"maximize"` when optimizing one or several metrics.
        backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
            The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending
            on which one is installed. If all are installed, will default to optuna.
        hp_name (`Callable[["optuna.Trial"], str]`, *optional*):
            A function that defines the trial/run name. Will default to None.
        kwargs (`dict[str, Any]`, *optional*):
            Additional keyword arguments for each backend:

            - `optuna`: parameters from
              [optuna.study.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
              and also the parameters `timeout`, `n_jobs` and `gc_after_trial` from
              [optuna.study.Study.optimize](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.optimize)
            - `ray`: parameters from [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run).
              If `resources_per_trial` is not set in the `kwargs`, it defaults to 1 CPU core and 1 GPU (if available).
              If `progress_reporter` is not set in the `kwargs`,
              [ray.tune.CLIReporter](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.CLIReporter.html) is used.

    Returns:
        [`trainer_utils.BestRun` or `list[trainer_utils.BestRun]`]: All the information about the best run or best
        runs for multi-objective optimization. Experiment summary can be found in `run_summary` attribute for Ray
        backend.

    Raises:
        RuntimeError: If the `Trainer` was created without a `model_init` function.
    """
    if backend is None:
        backend = default_hp_search_backend()
    backend = HPSearchBackend(backend)
    backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()
    backend_obj.ensure_available()

    if self.model_init is None:
        raise RuntimeError(
            "To use hyperparameter search, you need to pass your model through a model_init function."
        )

    self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
    self.hp_name = hp_name
    self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

    # Mark the trainer as being in a search only for the duration of the run, and always clear the marker
    # afterwards so that a failed search does not leave the trainer in hyperparameter-search mode.
    self.hp_search_backend = backend
    try:
        best_run = backend_obj.run(self, n_trials, direction, **kwargs)
    finally:
        self.hp_search_backend = None
    return best_run
def call_model_init(self, trial: "optuna.Trial | dict[str, Any] | None" = None) -> nn.Module:
    """
    Build a fresh model by calling `self.model_init`, passing `trial` when the function takes one argument.

    Args:
        trial (`optuna.Trial` or `dict[str, Any]`, *optional*):
            The trial forwarded to a one-argument `model_init`.

    Returns:
        `nn.Module`: The model returned by `model_init`.

    Raises:
        RuntimeError: If `model_init` takes more than one argument or returns `None`.
    """
    argument_count = number_of_arguments(self.model_init)
    if argument_count not in (0, 1):
        raise RuntimeError("model_init should have 0 or 1 argument.")

    model = self.model_init() if argument_count == 0 else self.model_init(trial)
    if model is None:
        raise RuntimeError("model_init should not return None.")
    return model
- def _hp_search_setup(self, trial: "optuna.Trial | dict[str, Any] | None") -> None:
- """Set up training arguments and accelerator state for a hyperparameter search trial."""
- self._trial = trial
- if self.hp_search_backend is None or trial is None:
- return
- if self.hp_search_backend == HPSearchBackend.OPTUNA:
- params = self.hp_space(trial)
- elif self.hp_search_backend == HPSearchBackend.RAY:
- params = trial
- params.pop("wandb", None)
- elif self.hp_search_backend == HPSearchBackend.WANDB:
- params = trial
- for key, value in params.items():
- if not hasattr(self.args, key):
- logger.warning(
- f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
- " `TrainingArguments`."
- )
- continue
- old_attr = getattr(self.args, key, None)
- # Casting value to the proper type
- if old_attr is not None:
- value = type(old_attr)(value)
- setattr(self.args, key, value)
- if self.hp_search_backend == HPSearchBackend.OPTUNA:
- logger.info(f"Trial: {trial.params}")
- if self.hp_search_backend == HPSearchBackend.WANDB:
- logger.info(f"W&B Sweep parameters: {trial}")
- if self.is_deepspeed_enabled:
- if self.args.deepspeed is None:
- raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
- self.accelerator.free_memory()
- # Rebuild the deepspeed config to reflect the updated training parameters
- from accelerate.utils import DeepSpeedPlugin
- from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
- self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
- self.args.hf_deepspeed_config.trainer_config_process(self.args)
- self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
- # From 1.0 on, we need to fully wipe the DS plugin when doing sweeps.
- # Simply calling `_reset_state` is enough and doesn't need a version pin.
- AcceleratorState()._reset_state()
- # `train_batch_size` might change when using HPO https://github.com/huggingface/transformers/pull/18918
- self._train_batch_size = self.args.train_batch_size
- self.create_accelerator_and_postprocess()
- def _report_to_hp_search(
- self, trial: "optuna.Trial | dict[str, Any] | None", step: int, metrics: dict[str, float]
- ) -> None:
- """Report intermediate metrics to the active hyperparameter search backend."""
- if self.hp_search_backend is None or trial is None:
- return
- metrics = metrics.copy()
- self.objective = self.compute_objective(metrics)
- if self.hp_search_backend == HPSearchBackend.OPTUNA:
- import optuna
- if hasattr(trial, "study") and not trial.study._is_multi_objective():
- trial.report(self.objective, step)
- if trial.should_prune():
- self.callback_handler.on_train_end(self.args, self.state, self.control)
- raise optuna.TrialPruned()
- elif self.hp_search_backend == HPSearchBackend.RAY:
- import ray.tune
- with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
- checkpoint = None
- if self.control.should_save:
- self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
- checkpoint = ray.tune.Checkpoint.from_directory(temp_checkpoint_dir)
- metrics["objective"] = self.objective
- ray.tune.report(metrics, checkpoint=checkpoint)
def _tune_save_checkpoint(self, checkpoint_dir: str) -> None:
    """
    Save the model and, where `args.should_save` is set, the trainer, optimizer and scheduler states.

    Args:
        checkpoint_dir (`str`):
            Parent folder; the checkpoint goes into a `{PREFIX_CHECKPOINT_DIR}-{global_step}` subfolder.
    """
    checkpoint_name = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
    output_dir = os.path.join(checkpoint_dir, checkpoint_name)
    self.save_model(output_dir, _internal_call=True)
    if not self.args.should_save:
        return

    # Update the `TrainerControl` state to where we are currently
    self.state.stateful_callbacks["TrainerControl"] = self.control.state()
    self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
    torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
- # ---- Callbacks ----
def add_callback(self, callback: type[TrainerCallback] | TrainerCallback) -> None:
    """
    Register a callback with the trainer's callback handler.

    Args:
        callback (`type` or [`~transformers.TrainerCallback`]):
            A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
            first case, will instantiate a member of that class.
    """
    handler = self.callback_handler
    handler.add_callback(callback)
def pop_callback(self, callback: type[TrainerCallback] | TrainerCallback) -> TrainerCallback | None:
    """
    Detach a [`~transformers.TrainerCallback`] from the current list of callbacks and hand it back.

    A missing callback is not an error: `None` is returned in that case.

    Args:
        callback (`type` or [`~transformers.TrainerCallback`]):
            Either a [`~transformers.TrainerCallback`] class or an instance of one. When a class is given, the
            first member of that class found in the list of callbacks is popped.

    Returns:
        [`~transformers.TrainerCallback`]: The callback that was removed, if found.
    """
    handler = self.callback_handler
    removed = handler.pop_callback(callback)
    return removed
def remove_callback(self, callback: type[TrainerCallback] | TrainerCallback) -> None:
    """
    Drop a [`~transformers.TrainerCallback`] from the current list of callbacks without returning it.

    Args:
        callback (`type` or [`~transformers.TrainerCallback`]):
            Either a [`~transformers.TrainerCallback`] class or an instance of one. When a class is given, the
            first member of that class found in the list of callbacks is removed.
    """
    handler = self.callback_handler
    handler.remove_callback(callback)
- # ---- Utilities ----
def is_local_process_zero(self) -> bool:
    """
    Tell whether this process is the local main process, i.e. the main process on its own machine when training
    is distributed over several machines.
    """
    local_index = self.args.local_process_index
    return local_index == 0
def is_world_process_zero(self) -> bool:
    """
    Tell whether this process is the global main process. When training is distributed over several machines,
    this is `True` for exactly one process.
    """
    # Under SageMaker ModelParallel, `process_index` is the dp_process_index rather than the global process
    # index, so the rank has to come from `smp` instead.
    if is_sagemaker_mp_enabled():
        global_rank = smp.rank()
    else:
        global_rank = self.args.process_index
    return global_rank == 0
def _move_model_to_device(self, model: nn.Module, device: torch.device) -> None:
    """
    Place `model` on `device`, re-tying its weights afterwards when running on TPU.

    A model whose `hf_device_map` attribute is set is left where it is, and a warning is logged instead.

    Args:
        model (`nn.Module`):
            The model to move.
        device (`torch.device`):
            The target device.
    """
    device_map = getattr(model, "hf_device_map", None)
    if device_map is None:
        model = model.to(device)
        # Moving a model to an XLA device disconnects the tied weights, so they have to be tied again.
        running_on_tpu = self.args.parallel_mode == ParallelMode.TPU
        if running_on_tpu and hasattr(model, "tie_weights"):
            model.tie_weights()
        return

    # A device map is present: do not override the existing placement.
    logger.warning("The model is already on multiple devices. Skipping the move to device specified in `args`.")
|