modeling_utils.py 248 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
7477847794780478147824783478447854786478747884789479047914792479347944795479647974798479948004801480248034804480548064807480848094810481148124813481448154816481748184819482048214822482348244825482648274828482948304831483248334834483548364837483848394840484148424843484448454846484748484849485048514852485348544855485648574858485948604861486248634864486548664867486848694870487148724873487448754876487748784879488048814882488348844885488648874888488948904891489248934894
  1. # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import collections
  16. import copy
  17. import functools
  18. import inspect
  19. import json
  20. import os
  21. import re
  22. import sys
  23. import warnings
  24. from abc import abstractmethod
  25. from collections import defaultdict
  26. from collections.abc import Callable, Iterator
  27. from contextlib import contextmanager
  28. from dataclasses import dataclass, field
  29. from functools import partial, wraps
  30. from itertools import cycle
  31. from threading import Thread
  32. from typing import Optional, TypeVar, get_type_hints
  33. from zipfile import is_zipfile
  34. import torch
  35. from huggingface_hub import create_repo, is_offline_mode, split_torch_state_dict_into_shards
  36. from packaging import version
  37. from safetensors import safe_open
  38. from safetensors.torch import save_file as safe_save_file
  39. from torch import Tensor, nn
  40. from torch.distributions import constraints
  41. from torch.utils.checkpoint import checkpoint
  42. from . import initialization as init
  43. from .configuration_utils import PreTrainedConfig
  44. from .conversion_mapping import get_model_conversion_mapping
  45. from .core_model_loading import (
  46. WeightConverter,
  47. WeightRenaming,
  48. convert_and_load_state_dict_in_model,
  49. revert_weight_conversion,
  50. )
  51. from .distributed import DistributedConfig
  52. from .dynamic_module_utils import custom_object_save
  53. from .generation import CompileConfig, GenerationConfig
  54. from .integrations import PeftAdapterMixin, deepspeed_config, hub_kernels, is_deepspeed_zero3_enabled, is_fsdp_enabled
  55. from .integrations.accelerate import (
  56. _get_device_map,
  57. accelerate_disk_offload,
  58. accelerate_dispatch,
  59. check_and_set_device_map,
  60. expand_device_map,
  61. get_device,
  62. load_offloaded_parameter,
  63. )
  64. from .integrations.deepspeed import _load_state_dict_into_zero3_model
  65. from .integrations.eager_paged import eager_paged_attention_forward
  66. from .integrations.flash_attention import flash_attention_forward
  67. from .integrations.flash_paged import paged_attention_forward
  68. from .integrations.flex_attention import flex_attention_forward
  69. from .integrations.hub_kernels import allow_all_hub_kernels, is_kernel
  70. from .integrations.peft import maybe_load_adapters
  71. from .integrations.sdpa_attention import sdpa_attention_forward
  72. from .integrations.sdpa_paged import sdpa_attention_paged_forward
  73. from .integrations.tensor_parallel import (
  74. ALL_PARALLEL_STYLES,
  75. _get_parameter_tp_plan,
  76. distribute_model,
  77. gather_state_dict_for_save,
  78. initialize_tensor_parallelism,
  79. shard_and_distribute_module,
  80. verify_tp_plan,
  81. )
  82. from .loss.loss_utils import LOSS_MAPPING
  83. from .modeling_flash_attention_utils import (
  84. FLASH_ATTENTION_COMPATIBILITY_MATRIX,
  85. FLASH_ATTN_KERNEL_FALLBACK,
  86. lazy_import_flash_attention,
  87. lazy_import_paged_flash_attention,
  88. )
  89. from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
  90. from .monkey_patching import apply_patches, patch_output_recorders
  91. from .pytorch_utils import id_tensor_storage
  92. from .quantizers import HfQuantizer
  93. from .quantizers.auto import get_hf_quantizer
  94. from .quantizers.quantizers_utils import get_module_from_name
  95. from .safetensors_conversion import auto_conversion
  96. from .utils import (
  97. ADAPTER_SAFE_WEIGHTS_NAME,
  98. DUMMY_INPUTS,
  99. SAFE_WEIGHTS_INDEX_NAME,
  100. SAFE_WEIGHTS_NAME,
  101. WEIGHTS_INDEX_NAME,
  102. WEIGHTS_NAME,
  103. ContextManagers,
  104. KernelConfig,
  105. PushToHubMixin,
  106. cached_file,
  107. check_torch_load_is_safe,
  108. copy_func,
  109. has_file,
  110. is_accelerate_available,
  111. is_bitsandbytes_available,
  112. is_env_variable_true,
  113. is_kernels_available,
  114. is_torch_flex_attn_available,
  115. is_torch_npu_available,
  116. is_torch_xpu_available,
  117. logging,
  118. )
  119. from .utils.generic import GeneralInterface, is_flash_attention_requested
  120. from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files
  121. from .utils.import_utils import (
  122. is_flash_attn_greater_or_equal,
  123. is_huggingface_hub_greater_or_equal,
  124. is_sagemaker_mp_enabled,
  125. is_torch_cuda_available,
  126. is_tracing,
  127. )
  128. from .utils.loading_report import LoadStateDictInfo, log_state_dict_report
  129. from .utils.output_capturing import _CAN_RECORD_REGISTRY, OutputRecorder
  130. from .utils.quantization_config import QuantizationMethod
# Optional dependency: the accelerate helpers are only imported when `is_accelerate_available()` is true.
if is_accelerate_available():
    from accelerate.hooks import add_hook_to_module
    from accelerate.utils import extract_model_from_parallel

# Evaluated once at import time: whether `torch.distributed` reports itself as available.
_torch_distributed_available = torch.distributed.is_available()

# SageMaker model parallelism: `smp` is imported only when enabled, and we record whether its version is >= 1.10.
# When it is not enabled, the flag is simply False.
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
else:
    IS_SAGEMAKER_MP_POST_1_10 = False

logger = logging.get_logger(__name__)

# Upper-cased raw values of the XLA bf16 environment variables; both fall back to "0" when unset.
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()

# Type variable bound (by forward reference) to `PreTrainedModel`.
SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")

# Module-level flags; `set_quantized_state()` and `set_zero3_state()` respectively set them to True
# for the duration of their `with` block.
_is_quantized = False
_is_ds_init_called = False
@dataclass(frozen=True)
class LoadStateDictConfig:
    """
    Config for loading weights. This allows bundling arguments that are just
    passed around.

    The dataclass is frozen, so fields cannot be reassigned once an instance is created. Every field has a
    default, which means `LoadStateDictConfig()` can be built without arguments.
    """

    pretrained_model_name_or_path: str | None = None
    # Built through a factory, so each instance gets its own `DownloadKwargs()` object
    download_kwargs: DownloadKwargs | None = field(default_factory=DownloadKwargs)
    use_safetensors: bool | None = None
    ignore_mismatched_sizes: bool = False
    sharded_metadata: dict | None = None
    device_map: dict | None = None
    disk_offload_folder: str | None = None
    offload_buffers: bool = False
    dtype: torch.dtype | None = None
    # Built through a factory, so each instance gets its own empty dict
    dtype_plan: dict = field(default_factory=dict)
    # When set, `is_quantized` below reports True
    hf_quantizer: HfQuantizer | None = None
    device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None
    # NOTE(review): presumably forwarded to `torch.load(weights_only=...)` - confirm against callers
    weights_only: bool = True
    weight_mapping: list[WeightConverter | WeightRenaming] | None = None

    @property
    def is_quantized(self) -> bool:
        """Whether a quantizer is attached, i.e. `hf_quantizer` is not None."""
        return self.hf_quantizer is not None
  170. def is_local_dist_rank_0():
  171. return (
  172. torch.distributed.is_available()
  173. and torch.distributed.is_initialized()
  174. and int(os.environ.get("LOCAL_RANK", "-1")) == 0
  175. )
  176. @contextmanager
  177. def set_quantized_state():
  178. global _is_quantized
  179. _is_quantized = True
  180. try:
  181. yield
  182. finally:
  183. _is_quantized = False
  184. # Skip recursive calls to deepspeed.zero.Init to avoid pinning errors.
  185. # This issue occurs with ZeRO stage 3 when using NVMe offloading.
  186. # For more details, refer to issue #34429.
  187. @contextmanager
  188. def set_zero3_state():
  189. global _is_ds_init_called
  190. _is_ds_init_called = True
  191. try:
  192. yield
  193. finally:
  194. _is_ds_init_called = False
  195. @contextmanager
  196. def local_torch_dtype(dtype: torch.dtype, model_class_name: str | None = None):
  197. """
  198. Locally change the torch default dtype to `dtype`, and restore the old one upon exiting the context.
  199. If `model_class_name` is provided, it's used to provide a more helpful error message if `dtype` is not valid.
  200. """
  201. # Just a more helping error before we set `torch.set_default_dtype` later on which would crash in this case
  202. if not dtype.is_floating_point:
  203. if model_class_name is not None:
  204. error_message = (
  205. f"{model_class_name} cannot be instantiated under `dtype={dtype}` as it's not a floating-point dtype"
  206. )
  207. else:
  208. error_message = f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype"
  209. raise ValueError(error_message)
  210. original_dtype = torch.get_default_dtype()
  211. try:
  212. torch.set_default_dtype(dtype)
  213. yield
  214. finally:
  215. torch.set_default_dtype(original_dtype)
  216. def get_torch_context_manager_or_global_device():
  217. """
  218. Test if a device context manager is currently in use, or if it is not the case, check if the default device
  219. is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
  220. """
  221. device_in_context = torch.tensor([]).device
  222. default_device = torch.get_default_device()
  223. # This case means no context manager was used -> we still check if the default that was potentially set is not cpu
  224. if device_in_context == default_device:
  225. if default_device != torch.device("cpu"):
  226. return default_device
  227. return None
  228. return device_in_context
  229. def get_state_dict_dtype(state_dict):
  230. """
  231. Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.
  232. """
  233. for t in state_dict.values():
  234. # We cannot instantiate a whole model under float4/8_xxx dtypes (torch does not allow setting them as default dtype)
  235. if t.is_floating_point() and "float8_" not in str(t.dtype) and "float4_" not in str(t.dtype):
  236. return t.dtype
  237. # if no floating dtype was found return whatever the first dtype is
  238. if len(state_dict) == 0:
  239. return torch.float32
  240. return next(iter(state_dict.values())).dtype
  241. str_to_torch_dtype = {
  242. "BOOL": torch.bool,
  243. "U8": torch.uint8,
  244. "I8": torch.int8,
  245. "I16": torch.int16,
  246. "U16": torch.uint16,
  247. "F16": torch.float16,
  248. "BF16": torch.bfloat16,
  249. "I32": torch.int32,
  250. "U32": torch.uint32,
  251. "F32": torch.float32,
  252. "F64": torch.float64,
  253. "I64": torch.int64,
  254. "U64": torch.uint64,
  255. "F8_E4M3": torch.float8_e4m3fn,
  256. "F8_E5M2": torch.float8_e5m2,
  257. }
  258. def load_state_dict(
  259. checkpoint_file: str | os.PathLike, map_location: str | torch.device = "cpu", weights_only: bool = True
  260. ) -> dict[str, torch.Tensor]:
  261. """
  262. Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
  263. """
  264. # Use safetensors if possible
  265. if checkpoint_file.endswith(".safetensors"):
  266. with safe_open(checkpoint_file, framework="pt") as f:
  267. state_dict = {}
  268. for k in f.keys():
  269. if map_location == "meta":
  270. _slice = f.get_slice(k)
  271. k_dtype = _slice.get_dtype()
  272. if k_dtype in str_to_torch_dtype:
  273. dtype = str_to_torch_dtype[k_dtype]
  274. else:
  275. raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
  276. state_dict[k] = torch.empty(size=_slice.get_shape(), dtype=dtype, device="meta")
  277. else:
  278. state_dict[k] = f.get_tensor(k).to(map_location)
  279. return state_dict
  280. # Fallback to torch.load (if weights_only was explicitly False, do not check safety as this is known to be unsafe)
  281. if weights_only:
  282. check_torch_load_is_safe()
  283. extra_args = {}
  284. # mmap can only be used with files serialized with zipfile-based format.
  285. if isinstance(checkpoint_file, str) and map_location != "meta" and is_zipfile(checkpoint_file):
  286. extra_args = {"mmap": True}
  287. return torch.load(checkpoint_file, map_location=map_location, weights_only=weights_only, **extra_args)
  288. def _end_ptr(tensor: torch.Tensor) -> int:
  289. # extract the end of the pointer if the tensor is a slice of a bigger tensor
  290. if tensor.nelement():
  291. stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
  292. else:
  293. stop = tensor.data_ptr()
  294. return stop
  295. def _get_tied_weight_keys(module: nn.Module) -> list[str]:
  296. tied_weight_keys: list[str] = []
  297. for name, submodule in module.named_modules():
  298. tied = getattr(submodule, "_tied_weights_keys", {}) or {}
  299. tied_weight_keys.extend([f"{name}.{k}" if name else k for k in tied.keys()])
  300. return tied_weight_keys
  301. def _find_disjoint(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]) -> tuple[list[set[str]], list[str]]:
  302. filtered_tensors = []
  303. for shared in tensors:
  304. if len(shared) < 2:
  305. filtered_tensors.append(shared)
  306. continue
  307. areas = []
  308. for name in shared:
  309. tensor = state_dict[name]
  310. areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
  311. areas.sort()
  312. _, last_stop, last_name = areas[0]
  313. filtered_tensors.append({last_name})
  314. for start, stop, name in areas[1:]:
  315. if start >= last_stop:
  316. filtered_tensors.append({name})
  317. else:
  318. filtered_tensors[-1].add(name)
  319. last_stop = stop
  320. disjoint_tensors = []
  321. shared_tensors = []
  322. for tensors in filtered_tensors:
  323. if len(tensors) == 1:
  324. disjoint_tensors.append(tensors.pop())
  325. else:
  326. shared_tensors.append(tensors)
  327. return shared_tensors, disjoint_tensors
  328. def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]) -> tuple[list[set[str]], set[str]]:
  329. shared_tensors = []
  330. identical = []
  331. for shared in tensors:
  332. if len(shared) < 2:
  333. continue
  334. areas = collections.defaultdict(set)
  335. for name in shared:
  336. tensor = state_dict[name]
  337. area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
  338. areas[area].add(name)
  339. if len(areas) == 1:
  340. identical.append(shared)
  341. else:
  342. shared_tensors.append(shared)
  343. return shared_tensors, identical
  344. def remove_tied_weights_from_state_dict(
  345. state_dict: dict[str, torch.Tensor], model: "PreTrainedModel"
  346. ) -> dict[str, torch.Tensor]:
  347. """
  348. Remove all tied weights from the given `state_dict`, making sure to keep only the main weight that `model`
  349. will expect when reloading (even if we now tie weights symmetrically, it's better to keep the intended one).
  350. This is because `safetensors` does not allow tensor aliasing - so we're going to remove aliases before saving.
  351. """
  352. # To avoid any potential mistakes and mismatches between config and actual tied weights, here we check the pointers
  353. # of the Tensors themselves -> we are guaranteed to find all the actual tied weights
  354. ptrs = collections.defaultdict(list)
  355. for name, tensor in state_dict.items():
  356. if not isinstance(tensor, torch.Tensor):
  357. # Sometimes in the state_dict we have non-tensor objects.
  358. # e.g. in bitsandbytes we have some `str` objects in the state_dict
  359. # In the non-tensor case, fall back to the pointer of the object itself
  360. ptrs[id(tensor)].append(name)
  361. elif tensor.device.type == "meta":
  362. # In offloaded cases, there may be meta tensors in the state_dict.
  363. # For these cases, key by the pointer of the original tensor object
  364. # (state_dict tensors are detached and therefore no longer shared)
  365. tensor = model.get_parameter(name)
  366. ptrs[id(tensor)].append(name)
  367. else:
  368. ptrs[id_tensor_storage(tensor)].append(name)
  369. shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
  370. # Recursively descend to find tied weight keys
  371. all_potential_tied_weights_keys = set(_get_tied_weight_keys(model))
  372. error_names = []
  373. to_delete_names = set()
  374. # Removing the keys which are declared as known duplicates on load. This allows to make sure the name which is
  375. # kept is consistent
  376. if all_potential_tied_weights_keys is not None:
  377. for names in shared_ptrs.values():
  378. found = 0
  379. for name in sorted(names):
  380. matches_pattern = any(re.search(pat, name) for pat in all_potential_tied_weights_keys)
  381. if matches_pattern and name in state_dict:
  382. found += 1
  383. if found < len(names):
  384. to_delete_names.add(name)
  385. # We are entering a place where the weights and the transformers configuration do NOT match.
  386. shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
  387. # Those are actually tensor sharing but disjoint from each other, we can safely clone them
  388. # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
  389. for name in disjoint_names:
  390. state_dict[name] = state_dict[name].clone()
  391. # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
  392. # If the link between tensors was done at runtime then `from_pretrained` will not get
  393. # the key back leading to random tensor. A proper warning will be shown
  394. # during reload (if applicable), but since the file is not necessarily compatible with
  395. # the config, better show a proper warning.
  396. shared_names, identical_names = _find_identical(shared_names, state_dict)
  397. # delete tensors that have identical storage
  398. for inames in identical_names:
  399. known = inames.intersection(to_delete_names)
  400. for name in known:
  401. del state_dict[name]
  402. unknown = inames.difference(to_delete_names)
  403. if len(unknown) > 1:
  404. error_names.append(unknown)
  405. if shared_names:
  406. error_names.extend(shared_names)
  407. if len(error_names) > 0:
  408. raise RuntimeError(
  409. f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. "
  410. f"We found all the potential target tied weights keys to be: {all_potential_tied_weights_keys}.\n"
  411. "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
  412. )
  413. return state_dict
  414. def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
  415. """Cast a single parameter or buffer `param_name` into the `model`, with value `tensor`."""
  416. parent, param_type = get_module_from_name(model, param_name)
  417. if param_type in parent._parameters and not isinstance(tensor, nn.Parameter):
  418. tensor = nn.Parameter(tensor, requires_grad=tensor.is_floating_point())
  419. # We need to use setattr here, as we set non-persistent buffers as well with this function (`load_state_dict`
  420. # does not allow to do it)
  421. setattr(parent, param_type, tensor)
  422. def _add_variant(weights_name: str, variant: str | None = None) -> str:
  423. if variant is not None:
  424. path, name = weights_name.rsplit(".", 1)
  425. weights_name = f"{path}.{variant}.{name}"
  426. return weights_name
def _get_resolved_checkpoint_files(
    pretrained_model_name_or_path: str | os.PathLike | None,
    variant: str | None,
    gguf_file: str | None,
    use_safetensors: bool | None,
    user_agent: dict | None,
    is_remote_code: bool,  # Because we can't determine this inside this function, we need it to be passed in
    transformers_explicit_filename: str | None = None,
    download_kwargs: DownloadKwargs | None = None,
    tqdm_class: type | None = None,
) -> tuple[list[str] | None, dict | None]:
    """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
    checkpoints are sharded.
    This function will download the data if necessary.

    Resolution order:
        - a local directory: explicit filename, then safetensors (single, sharded), then PyTorch `.bin` (single,
          sharded) - the `.bin` files are only considered when `use_safetensors` is falsy;
        - a local file: used as is;
        - otherwise the name is resolved through `cached_file`, trying safetensors first (unless
          `use_safetensors` is `False`) and falling back to `.bin` files;
        - if `gguf_file` is provided, only this file is resolved (locally, or through `cached_file`).

    Args:
        pretrained_model_name_or_path: Local directory, local file, or identifier passed to `cached_file`.
        variant: Optional variant inserted in the weight filenames (see `_add_variant`).
        gguf_file: Optional GGUF file to resolve instead of the regular weight files.
        use_safetensors: `True` to only accept safetensors files, `False` to skip them, `None` to try
            safetensors first and fall back to `.bin` files.
        user_agent: Forwarded to the download helpers.
        is_remote_code: When `True`, the automatic safetensors conversion is never attempted.
        transformers_explicit_filename: Explicit weight filename to load. Must end with `.safetensors` or
            `.safetensors.index.json`, or be exactly `adapter_model.bin`.
        download_kwargs: Download options; the keys read here are `cache_dir`, `force_download`, `proxies`,
            `local_files_only`, `token`, `revision` (defaults to `"main"`), `subfolder` and `commit_hash`.
        tqdm_class: Forwarded to the download helpers.

    Returns:
        A tuple `(checkpoint_files, sharded_metadata)`. `sharded_metadata` is `None` unless the checkpoint is
        sharded. For a non-sharded checkpoint, `checkpoint_files` is a one-element list, or `None` when
        `pretrained_model_name_or_path` is `None`.

    Raises:
        ValueError: if `transformers_explicit_filename` does not have one of the accepted names.
        OSError: if no suitable weight file can be found or loaded.
    """
    download_kwargs = download_kwargs or DownloadKwargs()
    cache_dir = download_kwargs.get("cache_dir")
    force_download = download_kwargs.get("force_download", False)
    proxies = download_kwargs.get("proxies")
    local_files_only = download_kwargs.get("local_files_only", False)
    token = download_kwargs.get("token")
    revision = download_kwargs.get("revision") or "main"
    subfolder = download_kwargs.get("subfolder", "")
    commit_hash = download_kwargs.get("commit_hash")
    # Validate the explicit filename early: only safetensors files/indexes are accepted, `adapter_model.bin`
    # being the single exception
    if transformers_explicit_filename is not None:
        if not transformers_explicit_filename.endswith(".safetensors") and not transformers_explicit_filename.endswith(
            ".safetensors.index.json"
        ):
            if transformers_explicit_filename != "adapter_model.bin":
                raise ValueError(
                    "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
                    "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
                    f"{transformers_explicit_filename}"
                )

    is_sharded = False

    if pretrained_model_name_or_path is not None and gguf_file is None:
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        is_local = os.path.isdir(pretrained_model_name_or_path)
        # If the file is a local folder (but not in the HF_HOME cache, even if it's technically local)
        if is_local:
            if transformers_explicit_filename is not None:
                # If the filename is explicitly defined, load this by default.
                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename)
                is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
            elif use_safetensors is not False and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
            ):
                # Load from a safetensors checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
                )
            elif use_safetensors is not False and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
            ):
                # Load from a sharded safetensors checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
                )
                is_sharded = True
            elif not use_safetensors and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
            ):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
                )
            elif not use_safetensors and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
            ):
                # Load from a sharded PyTorch checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
                )
                is_sharded = True
            elif use_safetensors:
                # safetensors were explicitly required but none is present
                raise OSError(
                    f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory"
                    f" {pretrained_model_name_or_path}."
                )
            else:
                raise OSError(
                    f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)}, or {_add_variant(WEIGHTS_NAME, variant)},"
                    f" found in directory {pretrained_model_name_or_path}."
                )
        # NOTE(review): the existence check joins `subfolder` *before* the path, while `archive_file` below does
        # not include `subfolder` - confirm this is intended
        elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
            archive_file = pretrained_model_name_or_path
            is_local = True
        else:
            # set correct filename
            if transformers_explicit_filename is not None:
                filename = transformers_explicit_filename
                is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
            elif use_safetensors is not False:
                filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
            else:
                filename = _add_variant(WEIGHTS_NAME, variant)

            # Prepare set of kwargs for hub functions
            has_file_kwargs = {
                "revision": revision,
                "proxies": proxies,
                "token": token,
                "cache_dir": cache_dir,
                "local_files_only": local_files_only,
            }
            cached_file_kwargs = {
                "force_download": force_download,
                "user_agent": user_agent,
                "subfolder": subfolder,
                "_raise_exceptions_for_gated_repo": False,
                "_raise_exceptions_for_missing_entries": False,
                "_commit_hash": commit_hash,
                "tqdm_class": tqdm_class,
                **has_file_kwargs,
            }
            can_auto_convert = (
                not is_offline_mode()  # for obvious reasons
                # If we are in a CI environment or in a pytest run, we prevent the conversion
                and not is_env_variable_true("DISABLE_SAFETENSORS_CONVERSION")
                and not is_remote_code  # converter bot does not work on remote code
                and subfolder == ""  # converter bot does not work on subfolders
            )

            try:
                # Load from URL or cache if already cached
                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                # result when internet is up, the repo and revision exist, but the file does not.
                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                # Try safetensors files first if not already found
                if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                    resolved_archive_file = cached_file(
                        pretrained_model_name_or_path,
                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                        **cached_file_kwargs,
                    )
                    if resolved_archive_file is not None:
                        is_sharded = True
                    elif use_safetensors:
                        # safetensors are mandatory: the only remaining option is the automatic conversion
                        if revision == "main" and can_auto_convert:
                            resolved_archive_file, revision, is_sharded = auto_conversion(
                                pretrained_model_name_or_path, **cached_file_kwargs
                            )
                        # Keep the kwargs in sync with the (possibly updated) revision
                        cached_file_kwargs["revision"] = revision
                        if resolved_archive_file is None:
                            raise OSError(
                                f"{pretrained_model_name_or_path} does not appear to have a file named"
                                f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
                                "and thus cannot be loaded with `safetensors`. Please do not set `use_safetensors=True`."
                            )
                    else:
                        # This repo has no safetensors file of any kind, we switch to PyTorch.
                        filename = _add_variant(WEIGHTS_NAME, variant)
                        resolved_archive_file = cached_file(
                            pretrained_model_name_or_path, filename, **cached_file_kwargs
                        )

                # Then try `.bin` files
                if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                    resolved_archive_file = cached_file(
                        pretrained_model_name_or_path,
                        _add_variant(WEIGHTS_INDEX_NAME, variant),
                        **cached_file_kwargs,
                    )
                    if resolved_archive_file is not None:
                        is_sharded = True

                # If we have a match, but it's `.bin` format, try to launch safetensors conversion for next time
                if resolved_archive_file is not None:
                    safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
                    if (
                        filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
                        and not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
                        and can_auto_convert
                    ):
                        # The conversion runs in its own thread, loading goes on with the `.bin` file(s)
                        Thread(
                            target=auto_conversion,
                            args=(pretrained_model_name_or_path,),
                            kwargs={"ignore_errors_during_conversion": False, **cached_file_kwargs},
                            name="Thread-auto_conversion",
                        ).start()

                # If no match, raise appropriate errors
                else:
                    # Otherwise, no PyTorch file was found
                    if variant is not None and has_file(
                        pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
                    ):
                        # Weights exist, just not for the requested variant
                        raise OSError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named"
                            f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
                            f" {variant}. Use `variant=None` to load this model from those weights."
                        )
                    else:
                        raise OSError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named"
                            f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}."
                        )

            except OSError:
                # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
                # to the original exception.
                raise
            except Exception as e:
                # For any other exception, we throw a generic error.
                raise OSError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
                    f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}."
                ) from e

        if is_local:
            logger.info(f"loading weights file {archive_file}")
            resolved_archive_file = archive_file
        else:
            logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")

    elif gguf_file:
        # Case 1: the GGUF file is present locally
        if os.path.isfile(gguf_file):
            resolved_archive_file = gguf_file
        # Case 2: The GGUF path is a location on the Hub
        # Load from URL or cache if already cached
        else:
            cached_file_kwargs = {
                "cache_dir": cache_dir,
                "force_download": force_download,
                "proxies": proxies,
                "local_files_only": local_files_only,
                "token": token,
                "user_agent": user_agent,
                "revision": revision,
                "subfolder": subfolder,
                "_raise_exceptions_for_gated_repo": False,
                "_raise_exceptions_for_missing_entries": False,
                "_commit_hash": commit_hash,
            }

            resolved_archive_file = cached_file(pretrained_model_name_or_path, gguf_file, **cached_file_kwargs)

    # We now download and resolve all checkpoint files if the checkpoint is sharded
    sharded_metadata = None
    if is_sharded:
        checkpoint_files, sharded_metadata = get_checkpoint_shard_files(
            pretrained_model_name_or_path,
            resolved_archive_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            user_agent=user_agent,
            revision=revision,
            subfolder=subfolder,
            _commit_hash=commit_hash,
            tqdm_class=tqdm_class,
        )
    else:
        # The condition is evaluated first, so `resolved_archive_file` is never read when no path was given
        checkpoint_files = [resolved_archive_file] if pretrained_model_name_or_path is not None else None

    return checkpoint_files, sharded_metadata
  679. def _get_dtype(
  680. dtype: str | torch.dtype | dict | None,
  681. checkpoint_files: list[str] | None,
  682. config: PreTrainedConfig,
  683. sharded_metadata: dict | None,
  684. state_dict: dict | None,
  685. weights_only: bool,
  686. hf_quantizer: HfQuantizer | None = None,
  687. ) -> tuple[PreTrainedConfig, torch.dtype]:
  688. """Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
  689. inferred dtype. We do the following:
  690. 1. If dtype is "auto", we try to read the config, else auto-detect dtype from the loaded state_dict, by checking
  691. its first weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
  692. 2. Else, use the dtype provided as a dict or str
  693. """
  694. is_sharded = sharded_metadata is not None
  695. if dtype is not None:
  696. if isinstance(dtype, str):
  697. if dtype == "auto":
  698. if hasattr(config, "dtype") and config.dtype is not None:
  699. dtype = config.dtype
  700. logger.info(f"Will use dtype={dtype} as defined in model's config object")
  701. else:
  702. if is_sharded and "dtype" in sharded_metadata:
  703. dtype = sharded_metadata["dtype"]
  704. elif state_dict is not None:
  705. dtype = get_state_dict_dtype(state_dict)
  706. else:
  707. state_dict = load_state_dict(
  708. checkpoint_files[0], map_location="meta", weights_only=weights_only
  709. )
  710. dtype = get_state_dict_dtype(state_dict)
  711. logger.info(
  712. "Since the `dtype` attribute can't be found in model's config object, "
  713. "will use dtype={dtype} as derived from model's weights"
  714. )
  715. elif hasattr(torch, dtype):
  716. dtype = getattr(torch, dtype)
  717. else:
  718. raise ValueError(
  719. "`dtype` provided as a `str` can only be `'auto'`, or a string representation of a valid `torch.dtype`"
  720. )
  721. # cast it to a proper `torch.dtype` object
  722. dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
  723. elif not isinstance(dtype, (dict, torch.dtype)):
  724. raise ValueError(
  725. f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` "
  726. f"for each sub-config in composite configs, but received {dtype}"
  727. )
  728. else:
  729. # set torch.get_default_dtype() (usually fp32) as the default dtype if `None` is provided
  730. dtype = torch.get_default_dtype()
  731. if hf_quantizer is not None:
  732. dtype = hf_quantizer.update_dtype(dtype)
  733. # Get the main dtype
  734. if isinstance(dtype, dict):
  735. main_dtype = dtype.get("", torch.get_default_dtype())
  736. main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype
  737. logger.warning_once(
  738. "Using different dtypes per module is deprecated and will be removed in future versions "
  739. "Setting different dtypes per backbone model might cause device errors downstream, therefore "
  740. f"setting the dtype={main_dtype} for all modules."
  741. )
  742. else:
  743. main_dtype = dtype
  744. # Set it on the config and subconfigs
  745. config.dtype = main_dtype
  746. for sub_config_key in config.sub_configs:
  747. if (sub_config := getattr(config, sub_config_key)) is not None:
  748. sub_config.dtype = main_dtype
  749. return config, main_dtype
  750. class ModuleUtilsMixin:
  751. """
  752. A few utilities for `torch.nn.Modules`, to be used as a mixin.
  753. """
  754. @property
  755. def device(self) -> torch.device:
  756. """
  757. `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
  758. device).
  759. """
  760. return next(param.device for param in self.parameters())
  761. @property
  762. def dtype(self) -> torch.dtype:
  763. """
  764. `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
  765. """
  766. return next(param.dtype for param in self.parameters() if param.is_floating_point())
  767. def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
  768. """
  769. Invert an attention mask (e.g., switches 0. and 1.).
  770. Args:
  771. encoder_attention_mask (`torch.Tensor`): An attention mask.
  772. Returns:
  773. `torch.Tensor`: The inverted attention mask.
  774. """
  775. if encoder_attention_mask.dim() == 3:
  776. encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
  777. if encoder_attention_mask.dim() == 2:
  778. encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
  779. # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
  780. # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
  781. # encoder_extended_attention_mask.transpose(-1, -2))
  782. encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
  783. encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min
  784. return encoder_extended_attention_mask
  785. @staticmethod
  786. def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
  787. device = attention_mask.device
  788. batch_size, seq_length = input_shape
  789. seq_ids = torch.arange(seq_length, device=device)
  790. causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
  791. # in case past_key_values are used we need to add a prefix ones mask to the causal mask
  792. causal_mask = causal_mask.to(attention_mask.dtype)
  793. if causal_mask.shape[1] < attention_mask.shape[1]:
  794. prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
  795. causal_mask = torch.cat(
  796. [
  797. torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
  798. causal_mask,
  799. ],
  800. axis=-1,
  801. )
  802. extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
  803. return extended_attention_mask
  804. def get_extended_attention_mask(
  805. self,
  806. attention_mask: Tensor,
  807. input_shape: tuple[int, ...],
  808. dtype: torch.dtype | None = None,
  809. ) -> Tensor:
  810. """
  811. Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
  812. Arguments:
  813. attention_mask (`torch.Tensor`):
  814. Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
  815. input_shape (`tuple[int]`):
  816. The shape of the input to the model.
  817. Returns:
  818. `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
  819. """
  820. if dtype is None:
  821. dtype = self.dtype
  822. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  823. # ourselves in which case we just need to make it broadcastable to all heads.
  824. if attention_mask.dim() == 3:
  825. extended_attention_mask = attention_mask[:, None, :, :]
  826. elif attention_mask.dim() == 2:
  827. # Provided a padding mask of dimensions [batch_size, seq_length]
  828. # - if the model is a decoder, apply a causal mask in addition to the padding mask
  829. # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
  830. if getattr(self.config, "is_decoder", None):
  831. extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
  832. input_shape, attention_mask
  833. )
  834. else:
  835. extended_attention_mask = attention_mask[:, None, None, :]
  836. else:
  837. raise ValueError(
  838. f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
  839. )
  840. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  841. # masked positions, this operation will create a tensor which is 0.0 for
  842. # positions we want to attend and the dtype's smallest value for masked positions.
  843. # Since we are adding it to the raw scores before the softmax, this is
  844. # effectively the same as removing these entirely.
  845. extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
  846. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
  847. return extended_attention_mask
  848. def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
  849. """
  850. Get number of (optionally, trainable or non-embeddings) parameters in the module.
  851. Args:
  852. only_trainable (`bool`, *optional*, defaults to `False`):
  853. Whether or not to return only the number of trainable parameters
  854. exclude_embeddings (`bool`, *optional*, defaults to `False`):
  855. Whether or not to return only the number of non-embeddings parameters
  856. Returns:
  857. `int`: The number of parameters.
  858. """
  859. if exclude_embeddings:
  860. embedding_param_names = [
  861. f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
  862. ]
  863. is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
  864. if is_loaded_in_4bit:
  865. import bitsandbytes as bnb
  866. total_params = 0
  867. for name, param in self.named_parameters():
  868. if exclude_embeddings and name in embedding_param_names:
  869. continue
  870. if param.requires_grad or not only_trainable:
  871. # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
  872. # used for the 4bit quantization (uint8 tensors are stored)
  873. if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
  874. if hasattr(param, "element_size"):
  875. num_bytes = param.element_size()
  876. elif hasattr(param, "quant_storage"):
  877. num_bytes = param.quant_storage.itemsize
  878. else:
  879. num_bytes = 1
  880. total_params += param.numel() * 2 * num_bytes
  881. else:
  882. total_params += param.numel()
  883. return total_params
  884. class EmbeddingAccessMixin:
  885. """
  886. Base utilities to regroup getters and setters for embeddings.
  887. Introduces the `input_layer_embed` attribute, which indicates
  888. where the input embeddings come from and where they
  889. should be set.
  890. """
  891. _input_embed_layer = "embed_tokens" # default layer that holds input embeddings.
  892. def get_input_embeddings(self) -> nn.Module:
  893. """
  894. Returns the model's input embeddings.
  895. Returns:
  896. `nn.Module`: A torch module mapping vocabulary to hidden states.
  897. """
  898. name = getattr(self, "_input_embed_layer", "embed_tokens")
  899. # 1) Direct attribute (most NLP models).
  900. if (default_embedding := getattr(self, name, None)) is not None:
  901. return default_embedding
  902. # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision/audio models).
  903. if hasattr(self, "embeddings") and hasattr(self.embeddings, name):
  904. return getattr(self.embeddings, name)
  905. # 3) Encoder/decoder wrappers (e.g., `self.model.embed_tokens` or similar overrides).
  906. if hasattr(self, "model") and hasattr(self.model, name):
  907. return getattr(self.model, name)
  908. if hasattr(self, "base_model"):
  909. base_model = self.base_model
  910. if base_model is not None and base_model is not self:
  911. return base_model.get_input_embeddings()
  912. raise NotImplementedError(
  913. f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
  914. )
  915. def set_input_embeddings(self, value: nn.Module):
  916. """Fallback setter that handles **~70%** of models in the code-base.
  917. Order of attempts:
  918. 1. `self.<_input_embed_layer>` (direct attribute)
  919. 2. `self.embeddings.<_input_embed_layer>` (nested embeddings for vision/audio models)
  920. 3. `self.model.<_input_embed_layer>` (encoder/decoder models)
  921. 4. delegate to the *base model* if one exists
  922. 5. otherwise raise `NotImplementedError` so subclasses still can (and
  923. should) override for exotic layouts.
  924. """
  925. name = getattr(self, "_input_embed_layer", "embed_tokens")
  926. # 1) Direct attribute (most NLP models)
  927. if hasattr(self, name):
  928. setattr(self, name, value)
  929. # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision models)
  930. elif hasattr(self, "embeddings") and hasattr(self.embeddings, name):
  931. setattr(self.embeddings, name, value)
  932. # 3) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
  933. elif hasattr(self, "model") and hasattr(self.model, name):
  934. setattr(self.model, name, value)
  935. # 4) recurse once into the registered *base* model (e.g. for encoder/decoder)
  936. elif hasattr(self, "base_model") and self.base_model is not self:
  937. self.base_model.set_input_embeddings(value)
  938. else:
  939. raise NotImplementedError(
  940. f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
  941. )
  942. def get_output_embeddings(self):
  943. if not hasattr(self, "lm_head"):
  944. return None
  945. try:
  946. # Speech / vision backbones raise here, so we return None.
  947. # Legit use of get_input_embs?
  948. self.get_input_embeddings()
  949. except NotImplementedError:
  950. return None
  951. return self.lm_head
  952. def set_output_embeddings(self, new_embeddings):
  953. """
  954. Sets the model's output embedding, defaulting to setting new_embeddings to lm_head.
  955. """
  956. if getattr(self, "lm_head"):
  957. self.lm_head = new_embeddings
  958. class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
  959. r"""
  960. Base class for all models.
  961. [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
  962. downloading and saving models as well as a few methods common to all models to:
  963. - resize the input embeddings
  964. Class attributes (overridden by derived classes):
  965. - **config_class** ([`PreTrainedConfig`]) -- A subclass of [`PreTrainedConfig`] to use as configuration class
  966. for this model architecture.
  967. - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
  968. classes of the same architecture adding modules on top of the base model.
  969. - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
  970. models, `pixel_values` for vision models and `input_values` for speech models).
  971. - **can_record_outputs** (dict):
  972. """
  # General model properties
  config_class: type[PreTrainedConfig] | None = None
  _auto_class = None
  base_model_prefix: str = ""
  _is_stateful: bool = False
  model_tags: list[str] | None = None

  # Input-related properties
  main_input_name: str = "input_ids"
  # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
  # Possible values are: text, image, video, audio and time
  input_modalities: str | list[str] = "text"

  # Device-map related properties
  _no_split_modules: set[str] | list[str] | None = None
  _skip_keys_device_placement: str | list[str] | None = None

  # Specific dtype upcasting
  # `_keep_in_fp32_modules` will upcast to fp32 only if the requested dtype is fp16
  # `_keep_in_fp32_modules_strict` will upcast to fp32 independently if the requested dtype is fp16 or bf16
  _keep_in_fp32_modules: set[str] | list[str] | None = None
  _keep_in_fp32_modules_strict: set[str] | list[str] | None = None

  # Loading-specific properties
  # A dictionary `{"target": "source"}` of checkpoint keys that are potentially tied to one another
  # (the default is `None`, so the annotation is optional)
  _tied_weights_keys: dict[str, str] | None = None
  # Used for BC support in VLMs, not meant to be used by new models
  _checkpoint_conversion_mapping: dict[str, str] = {}
  # A list of `re` patterns describing keys to ignore if they are missing from checkpoints to avoid warnings
  _keys_to_ignore_on_load_missing: list[str] | None = None
  # A list of `re` patterns describing keys to ignore if they are unexpected in the checkpoints to avoid warnings
  _keys_to_ignore_on_load_unexpected: list[str] | None = None
  # A list of keys to ignore when saving the model
  _keys_to_ignore_on_save: list[str] | None = None

  # Attention interfaces support properties
  _supports_sdpa: bool = False
  _supports_flash_attn: bool = False
  _supports_flex_attn: bool = False
  # Model's compatible flash kernels (e.g., "kernels-community/flash-mla") defaulting to the first in the list
  _compatible_flash_implementations: list[str] | None = None

  # Tensor-parallelism-related properties
  # A tensor parallel plan of the form `{"model.layer.mlp.param": "colwise"}` to be applied to the model when TP is enabled.
  # For top-level models, this attribute is currently defined in respective model code. For base models, this attribute comes
  # from `config.base_model_tp_plan` during `post_init`.
  _tp_plan: dict[str, str] | None = None
  # Tensor parallel degree to which model is sharded to
  _tp_size = None
  # A pipeline parallel plan specifying the layers which may not be present on all ranks when PP is enabled. For top-level
  # models, this attribute is currently defined in respective model code. For base models, it comes from
  # `config.base_model_pp_plan` during `post_init`.
  _pp_plan: dict[str, tuple[str, str]] | None = None

  # Advanced functionalities support
  supports_gradient_checkpointing: bool = False
  _can_compile_fullgraph: bool = False
  # This flag signals that the model can be used as an efficient backend in TGI and vLLM
  # In practice, it means that they support attention (mask) interface functions, fully pass the kwargs
  # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
  _supports_attention_backend: bool = False
  # A mapping describing what outputs can be captured by `capture_outputs` decorator during the forward pass
  _can_record_outputs: dict | None = None
  @property
  @torch.compiler.allow_in_graph
  def can_record_outputs(self) -> dict[str, OutputRecorder]:
      """
      Mapping of recordable output names (e.g., "attentions", "hidden_states") to what produces them. An empty
      dict is returned when `_can_record_outputs` is unset or empty.

      Each value is either:

      - A module class (e.g., `LlamaDecoderLayer`), using default index conventions:
          * index=0 for "hidden_states"
          * index=1 for "attentions"
      - Or an `OutputRecorder(...)` with `target_class`, optional `index`, and `layer_name`.

      Examples:
          These two are equivalent:

      ```python
          _can_record_outputs = {
              "attentions": LlamaAttention,
              "hidden_states": LlamaDecoderLayer
          }

          _can_record_outputs = {
              "attentions": OutputRecorder(LlamaAttention, index=1),
              "hidden_states": OutputRecorder(LlamaDecoderLayer, index=0)
          }
      ```

      Outputs coming from the same class can be told apart by specifying a layer name: before collecting
      outputs, we check that they come from this layer.

      If you have cross attention that comes from `LlamaAttention` and self attention that also comes from
      `LlamaAttention` but from `self_attn`, you can do this:

      ```python
      class LlamaModel(PreTrainedModel):
          _can_record_outputs = {
              "attentions": OutputRecorder(LlamaAttention, index=1, layer_name="self_attn"),
              "cross_attentions": OutputRecorder(LlamaAttention, index=1, layer_name="cross_attn")
          }
      ```
      """
      recorders = self._can_record_outputs
      return recorders if recorders else {}
  @property
  def dummy_inputs(self) -> dict[str, torch.Tensor]:
      """
      `dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network, built from `DUMMY_INPUTS`
      and stored under the `"input_ids"` key.
      """
      input_ids = torch.tensor(DUMMY_INPUTS)
      return {"input_ids": input_ids}
  def __init_subclass__(cls, **kwargs):
      """
      Resolve `config_class` for every new subclass.

      For BC the explicit `config_class` attribute is kept when there is one (e.g. remote code models);
      otherwise it is derived from the annotated `config` attribute. Nothing is assigned when no source
      provides a value.
      """
      super().__init_subclass__(**kwargs)

      # Declared directly on this particular subclass
      own_annotation = inspect.get_annotations(cls).get("config", None)
      own_attribute = cls.__dict__.get("config_class", None)
      # Visible on the class, whether declared here or on any parent
      inherited_annotation = get_type_hints(cls).get("config", None)
      inherited_attribute = cls.config_class

      # Listed by priority: child `config_class` -> child annotation -> global `config_class` -> global annotation
      by_priority = (own_attribute, own_annotation, inherited_attribute, inherited_annotation)
      for candidate in by_priority:
          if candidate is not None:
              cls.config_class = candidate
              break
  def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
      """
      Store the configuration and resolve the attention / experts implementations, the generation config
      (for models that can generate) and the loss type.

      Args:
          config ([`PreTrainedConfig`]):
              The model configuration. It is stored as `self.config`, and its `name_or_path` as
              `self.name_or_path`.
          *inputs, **kwargs:
              Accepted but not used by this initializer.

      Raises:
          TypeError: If `config` is not an instance of `PreTrainedConfig`.
      """
      super().__init__()
      model_name = self.__class__.__name__
      if not isinstance(config, PreTrainedConfig):
          raise TypeError(
              f"Parameter config in `{model_name}(config)` should be an instance of class "
              "`PreTrainedConfig`. To create a model from a pretrained model use "
              f"`model = {model_name}.from_pretrained(PRETRAINED_MODEL_NAME)`"
          )
      self.config = config
      self.name_or_path = config.name_or_path

      # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
      # setting it recursively)
      self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
          self.config._attn_implementation,
          is_init_check=True,
          # We need to use this constant that is set through context manager as it cannot be forwarded in the model's __init__
          allow_all_kernels=hub_kernels.ALLOW_ALL_KERNELS,
      )
      # Same for the experts implementation (also stored on the internal attr, to avoid setting it recursively)
      self.config._experts_implementation_internal = self._check_and_adjust_experts_implementation(
          self.config._experts_implementation
      )

      if self.can_generate():
          self.generation_config = GenerationConfig.from_model_config(config)

      # For initialization of the loss: use the class name when it is a key of `LOSS_MAPPING`, otherwise the
      # first `LOSS_MAPPING` key found inside the class name (or `None` when there is no such key)
      if model_name in LOSS_MAPPING:
          resolved_loss_type = model_name
      else:
          known_losses = f"({'|'.join(LOSS_MAPPING)})"
          found = re.findall(known_losses, model_name)
          resolved_loss_type = found[0] if len(found) > 0 else None
      self.loss_type = resolved_loss_type

      _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs  # added for executorch support only
  def post_init(self):
      """
      A method executed at the end of each Transformer model initialization, to execute code that needs the model's
      modules properly initialized (such as weight initialization).

      It also gathers the static properties (parallelism plans, tied weights keys, `_keep_in_fp32_modules`,
      `_keep_in_fp32_modules_strict` and `_no_split_modules`) of the direct children, so that a composite
      top-level model knows about the properties of its submodels.
      """
      # Start from empty plans, so that everything is attached to the top-most model and easily available
      self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {}
      # If current model is a base model, attach the `base_model_*_plan`s from the config (copied)
      if self.base_model is self:
          for plan_attr, config_attr in (
              ("_pp_plan", "base_model_pp_plan"),
              ("_tp_plan", "base_model_tp_plan"),
              ("_ep_plan", "base_model_ep_plan"),
          ):
              base_plan = getattr(self.config, config_attr)
              setattr(self, plan_attr, base_plan.copy() if base_plan is not None else {})

      # Current submodel registers its own tied weights...
      self.all_tied_weights_keys = self.get_expanded_tied_weights_keys(all_submodels=False)
      # ... as well as its own `_keep_in_fp32_modules`, `_keep_in_fp32_modules_strict` and `_no_split_modules`
      self._keep_in_fp32_modules = set(self._keep_in_fp32_modules or [])
      self._keep_in_fp32_modules_strict = set(self._keep_in_fp32_modules_strict or [])
      self._no_split_modules = set(self._no_split_modules or [])

      # Iterate over children only: as the final model is created, this is enough to gather the properties from all submodels.
      # This works because the way the `__init__` and `post_init` are called on all submodules is depth-first in the graph
      for child_name, child in self.named_children():
          # Parallel plans: keys get prefixed with the name of the child
          for plan_attr in ("_ep_plan", "_tp_plan", "_pp_plan"):
              if child_plan := getattr(child, plan_attr, None):
                  prefixed_plan = {f"{child_name}.{key}": value for key, value in child_plan.copy().items()}
                  getattr(self, plan_attr).update(prefixed_plan)

          # Always attach the keys of the children (if the children's config says to NOT tie, then it's empty)
          if child_tied_keys := getattr(child, "all_tied_weights_keys", None):
              prefixed_tied_keys = {
                  f"{child_name}.{target}": f"{child_name}.{source}"
                  for target, source in child_tied_keys.copy().items()
              }
              self.all_tied_weights_keys.update(prefixed_tied_keys)

          # Record the fp32 modules and the `_no_split_modules` from the children as well
          for names_attr in ("_keep_in_fp32_modules", "_keep_in_fp32_modules_strict", "_no_split_modules"):
              if child_names := getattr(child, names_attr, None):
                  getattr(self, names_attr).update(child_names)

      # Maybe initialize the weights and tie the keys
      self.init_weights()
      self._backward_compatibility_gradient_checkpointing()
  1172. @property
  1173. def tp_plan(self) -> dict[str, str]:
  1174. """
  1175. The full tp plan for the model's modules
  1176. """
  1177. if hasattr(self.config, "distributed_config") and self.config.distributed_config.enable_expert_parallel:
  1178. return self._ep_plan
  1179. return self._tp_plan
  1180. @property
  1181. def pp_plan(self) -> dict[str, tuple[str, str]]:
  1182. return self._pp_plan
  1183. @tp_plan.setter
  1184. def tp_plan(self, plan: dict[str, str] | None):
  1185. if plan is None:
  1186. self._tp_plan = {}
  1187. return
  1188. if not isinstance(plan, dict):
  1189. raise ValueError("Can only set a dictionary as `tp_plan`")
  1190. # Ensure the styles are all valid
  1191. for layer_pattern, parallel_style in plan.items():
  1192. if parallel_style not in ALL_PARALLEL_STYLES:
  1193. raise ValueError(
  1194. f"Unsupported tensor parallel style '{parallel_style}' for layer '{layer_pattern}'. "
  1195. f"Supported styles are {list(ALL_PARALLEL_STYLES.keys())}"
  1196. )
  1197. # Validate that the layer patterns match existing model structure. We check this by getting all parameter
  1198. # names and seeing if any match the patterns
  1199. model_param_names = [name for name, _ in self.named_parameters()]
  1200. for layer_pattern in plan.keys():
  1201. # Convert pattern to regex (replace * with .*)
  1202. regex_pattern = layer_pattern.replace("*", r"\d+")
  1203. pattern_matched = False
  1204. for param_name in model_param_names:
  1205. if re.match(regex_pattern, param_name):
  1206. pattern_matched = True
  1207. break
  1208. if not pattern_matched:
  1209. warnings.warn(
  1210. f"Layer pattern '{layer_pattern}' does not match any parameters in the model. This rule may not "
  1211. "be applied during tensor parallelization, or may lead to dimension mismatches"
  1212. )
  1213. # Set the plan
  1214. self._tp_plan = plan
  1215. @pp_plan.setter
  1216. def pp_plan(self, plan: dict[str, tuple[str, str]] | None):
  1217. if plan is None:
  1218. self._pp_plan = {}
  1219. return
  1220. if not isinstance(plan, dict):
  1221. raise ValueError("Can only set a dictionary as `pp_plan`")
  1222. self._pp_plan = plan
  def dequantize(self, dtype=None):
      """
      Potentially dequantize the model in case it has been quantized by a quantization method that support
      dequantization.

      Args:
          dtype (*optional*):
              Forwarded as-is to the quantizer's `dequantize` method.

      Returns:
          Whatever `hf_quantizer.dequantize(self, dtype=dtype)` returns.

      Raises:
          ValueError: If the model has no `hf_quantizer` attribute, or if it is `None`.
      """
      quantizer = getattr(self, "hf_quantizer", None)
      if quantizer is not None:
          return quantizer.dequantize(self, dtype=dtype)
      raise ValueError("You need to first quantize your model in order to dequantize it")
  1232. def _backward_compatibility_gradient_checkpointing(self):
  1233. if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
  1234. self.gradient_checkpointing_enable()
  1235. # Remove the attribute now that is has been consumed, so it's no saved in the config.
  1236. delattr(self.config, "gradient_checkpointing")
  1237. def add_model_tags(self, tags: list[str] | str) -> None:
  1238. r"""
  1239. Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
  1240. not overwrite existing tags in the model.
  1241. Args:
  1242. tags (`Union[list[str], str]`):
  1243. The desired tags to inject in the model
  1244. Examples:
  1245. ```python
  1246. from transformers import AutoModel
  1247. model = AutoModel.from_pretrained("google-bert/bert-base-cased")
  1248. model.add_model_tags(["custom", "custom-bert"])
  1249. # Push the model to your namespace with the name "my-custom-bert".
  1250. model.push_to_hub("my-custom-bert")
  1251. ```
  1252. """
  1253. if isinstance(tags, str):
  1254. tags = [tags]
  1255. if self.model_tags is None:
  1256. self.model_tags = []
  1257. for tag in tags:
  1258. if tag not in self.model_tags:
  1259. self.model_tags.append(tag)
  @classmethod
  def _from_config(cls, config, **kwargs):
      """
      All context managers that the model should be initialized under go here.

      Args:
          config:
              The configuration used to instantiate the model. Its `dtype` is the default dtype, and the final
              dtype is also written on every non-`None` sub-config.
          dtype (`torch.dtype` or `str`, *optional*):
              Override the default `dtype` and load the model under this dtype. A string is resolved as an
              attribute of `torch`.
          torch_dtype (*optional*):
              Deprecated alias of `dtype`, only used when `dtype` was left at the config value.
          attn_implementation (*optional*):
              Stored on `config._attn_implementation` before the model is created.
          experts_implementation (*optional*):
              Stored on `config._experts_implementation` before the model is created.
          allow_all_kernels (`bool`, *optional*, defaults to `False`):
              Whether to instantiate the model under the `allow_all_hub_kernels()` context.
          kwargs:
              Every remaining keyword argument is forwarded to `cls(config, **kwargs)`.

      Returns:
          The instantiated model.
      """
      # For BC on the old `torch_dtype`
      dtype = kwargs.pop("dtype", config.dtype)
      if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
          logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
          # if both kwargs are provided, use `dtype`
          dtype = dtype if dtype != config.dtype else torch_dtype
      if isinstance(dtype, str):
          dtype = getattr(torch, dtype)

      # Set the same `dtype` on all subconfigs to avoid dtype mismatch. When "auto" dtype
      # with nested models, we can't dispatch different dtype per backbone module
      for sub_config_key in config.sub_configs:
          if (sub_config := getattr(config, sub_config_key)) is not None:
              sub_config.dtype = dtype

      # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
      if "attn_implementation" in kwargs:
          config._attn_implementation = kwargs.pop("attn_implementation")

      # If passing `experts_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
      if "experts_implementation" in kwargs:
          config._experts_implementation = kwargs.pop("experts_implementation")

      # Needed if the attn_implementation is an outside `kernels-community` kernel.
      # The flag is popped (not just read): it is applied through the `allow_all_hub_kernels()` context below
      # and must not be forwarded to the model's `__init__` with the remaining kwargs.
      allow_all_kernels = kwargs.pop("allow_all_kernels", False)

      init_contexts = [apply_patches()]
      if dtype is not None:
          init_contexts.append(local_torch_dtype(dtype, cls.__name__))
      if allow_all_kernels:
          init_contexts.append(allow_all_hub_kernels())

      needs_zero3_init = is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called
      if needs_zero3_init:
          logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
          # this immediately partitions the model across all gpus, to avoid the overhead in time
          # and memory copying it on CPU or each GPU first
          import deepspeed

          init_contexts.extend(
              [
                  init.no_init_weights(),
                  deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
                  set_zero3_state(),
              ]
          )

      # Instantiate the model
      with ContextManagers(init_contexts):
          model = cls(config, **kwargs)
          patch_output_recorders(model)

      # Under ZeRO-3, parameters were partitioned into empty tensors during construction,
      # so weight init was suppressed. Re-initialize using the ZeRO-3 variant which gathers
      # each module's parameters before init to avoid OOM on large models.
      if needs_zero3_init:
          from .integrations.deepspeed import initialize_weights_zero3

          initialize_weights_zero3(model)
          model.tie_weights()

      return model
  @property
  def base_model(self) -> nn.Module:
      """
      `torch.nn.Module`: The main body of the model, i.e. the attribute named by `base_model_prefix`, or the
      model itself when it has no such attribute.
      """
      prefix = self.base_model_prefix
      return getattr(self, prefix, self)
  @classmethod
  def can_generate(cls) -> bool:
      """
      Returns whether this model can generate sequences with `.generate()` from the `GenerationMixin`.

      Under the hood, on classes where this function returns True, some generation-specific changes are triggered:
      for instance, the model instance will have a populated `generation_config` attribute.

      Returns:
          `bool`: Whether this model can generate sequences with `.generate()`.
      """
      direct_bases = cls.__bases__

      # Directly inherits `GenerationMixin` -> can generate
      if "GenerationMixin" in str(direct_bases):
          return True

      # The class inherits from a class that can generate (recursive check) -> can generate.
      # Bases whose string form mentions `PreTrainedModel` are not asked.
      for parent in direct_bases:
          if hasattr(parent, "can_generate") and "PreTrainedModel" not in str(parent) and parent.can_generate():
              return True

      # Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
      # was how we detected whether a model could generate. Such a class only gets a warning here.
      if hasattr(cls, "prepare_inputs_for_generation"):  # implicit: doesn't inherit `GenerationMixin`
          logger.warning(
              f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
              "defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
              "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
              "to call `generate` and other related functions."
              "\n  - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
              "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes"
              "\n  - If you are the owner of the model architecture code, please modify your model class such that "
              "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)."
              "\n  - If you are not the owner of the model architecture class, please contact the model code owner "
              "to update it."
          )

      # Otherwise, can't generate
      return False
  1360. def _flash_attn_import_error(
  1361. self,
  1362. flash_attn_version: int,
  1363. general_availability_check: Callable,
  1364. pkg_availability_check: Callable,
  1365. supported_devices: tuple[tuple[Callable, str]],
  1366. custom_supported_devices: tuple[tuple[Callable, str]] = (),
  1367. cuda_min_major_version: int | None = None,
  1368. ):
  1369. """
  1370. Checks whether the specified Flash Attention version is supported and if not, searches for the specific reason
  1371. on why it failed - package import and/or device incompatibility issues.
  1372. Args:
  1373. flash_attn_version (`int`):
  1374. The requested version of Flash Attention.
  1375. general_availability_check (`Callable`):
  1376. Checks whether our `is_available` function detects the specific FA version. Failing reasons
  1377. are then checked for one-by-one.
  1378. pkg_availability_check (`Callable`):
  1379. Checks whether the package could theoretically be detected in the environment by the init structures.
  1380. This is not a sure-fire check as device compatibility with FA is just as important.
  1381. supported_devices (`tuple[tuple[Callable, str]]`):
  1382. Essentially a list (for mutable kwargs reasons a tuple) of the supported devices in the format of
  1383. `(device_availability_check, device_name)`, i.e. a pair of the associated device's name and whether
  1384. it is available in the environment.
  1385. custom_supported_devices (`tuple[tuple[Callable, str]]`, *optional*, defaults to `()`):
  1386. Essentially a list (for mutable kwargs reasons a tuple) of the custom supported devices in the format of
  1387. `(device_availability_check, info_message)`. These custom devices have custom logic outside the torch
  1388. ecosystem either via kernels or other packages and hence have early checks for availability.
  1389. cuda_min_major_version (`int`, *optional*):
  1390. The minimum major cuda version supported for this version of Flash Attention. This is mostly
  1391. affecting more recent versions which are more specialized to the features of new hardware.
  1392. """
  1393. # Certain devices have custom workarounds e.g. with their own package distribution (NPU) or via kernels (XPU)
  1394. for device_availability_check, info_message in custom_supported_devices:
  1395. if device_availability_check():
  1396. logger.info(info_message)
  1397. return
  1398. if not general_availability_check():
  1399. preface = f"FlashAttention{flash_attn_version} has been toggled on, but it cannot be used due to the following error:"
  1400. # Can the package be seen in the import structure
  1401. if not pkg_availability_check():
  1402. raise ImportError(
  1403. f"{preface} the package for FlashAttention{flash_attn_version} doesn't seem to be installed."
  1404. )
  1405. # Minimum version (FA2 only)
  1406. elif flash_attn_version == 2 and not is_flash_attn_greater_or_equal("2.3.3"):
  1407. raise ImportError(f"{preface} FlashAttention{flash_attn_version} requires at least version `2.3.3`.")
  1408. else:
  1409. # Supported devices availability
  1410. device_availability_checks, device_names = zip(*supported_devices)
  1411. if not any(device_availability_check() for device_availability_check in device_availability_checks):
  1412. raise ImportError(
  1413. f"{preface} FlashAttention{flash_attn_version} is not available on CPU. Please make sure you are on any of the supported devices: {device_names}."
  1414. )
  1415. # Cuda major versions (more recent FA versions are specialized to newer cuda devices)
  1416. elif cuda_min_major_version is not None and is_torch_cuda_available():
  1417. major, _ = torch.cuda.get_device_capability()
  1418. if major < cuda_min_major_version:
  1419. raise ImportError(
  1420. f"{preface} FlashAttention{flash_attn_version} requires compute capability >= {cuda_min_major_version}, but found {torch.cuda.get_device_capability()} with compute capability {major}.x"
  1421. )
  1422. def _flash_attn_can_dispatch(self, flash_attn_version: int, is_init_check: bool = False) -> bool:
  1423. """
  1424. Check the availability of Flash Attention for a given model.
  1425. Args:
  1426. flash_attn_version (`int`):
  1427. The requested version of Flash Attention.
  1428. is_init_check (`bool`, *optional*):
  1429. Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
  1430. fully instantiated. This is needed as we also check the devices of the weights, which are only available
  1431. later after __init__. This allows to raise proper exceptions early before instantiating the full models
  1432. if we know that the model does not support the requested attention.
  1433. """
  1434. if not self._supports_flash_attn:
  1435. raise ValueError(
  1436. f"{self.__class__.__name__} does not support Flash Attention {flash_attn_version} yet. Please request to add support where"
  1437. f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new"
  1438. " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
  1439. )
  1440. if flash_attn_version not in [2, 3, 4]:
  1441. raise ValueError(f"Requested Flash Attention {flash_attn_version} which is not supported.")
  1442. # Check if we can even use the FA version based on the env of the user
  1443. self._flash_attn_import_error(**FLASH_ATTENTION_COMPATIBILITY_MATRIX[flash_attn_version])
  1444. # Check for attention dropout, which is incompatible with newer FA versions
  1445. # (many should not really care about dropout as it is not super effective, hence warning for now)
  1446. if flash_attn_version > 2:
  1447. if hasattr(self.config, "attention_dropout") and self.config.attention_dropout > 0:
  1448. logger.warning_once(
  1449. f"You are attempting to use Flash Attention {flash_attn_version} with dropout. "
  1450. "This might lead to unexpected behaviour as this is not supported on recent versions of Flash Attention."
  1451. )
  1452. # People often move dtypes after init so we only warn in those cases
  1453. dtype = self.config.dtype
  1454. if dtype is None:
  1455. logger.warning_once(
  1456. f"You are attempting to use Flash Attention {flash_attn_version} without specifying a dtype. This might lead to unexpected behaviour"
  1457. )
  1458. elif dtype is not None and dtype not in [torch.float16, torch.bfloat16]:
  1459. logger.warning_once(
  1460. f"Flash Attention {flash_attn_version} only supports torch.float16 and torch.bfloat16 dtypes, but"
  1461. f" the current dype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
  1462. f' or load the model with the `dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_{flash_attn_version}", dtype=torch.float16)`'
  1463. )
  1464. # With the early check, the parameters are not yet initialized correctly
  1465. if not is_init_check:
  1466. param_devices = list({param.device for param in self.parameters()})
  1467. if len(param_devices) == 1 and param_devices[0].type == "cpu":
  1468. found_device = False
  1469. for device_availability_check, device_name in FLASH_ATTENTION_COMPATIBILITY_MATRIX[flash_attn_version][
  1470. "supported_devices"
  1471. ]:
  1472. if device_availability_check():
  1473. found_device = True
  1474. logger.warning_once(
  1475. f"You are attempting to use Flash Attention {flash_attn_version} with a model not initialized on GPU. Please make sure to have "
  1476. "access to a GPU and either initialise the model on a GPU by passing a device_map or initialising the model on CPU and then "
  1477. f"moving it to GPU, e.g. with `model.to('{device_name}')`."
  1478. )
  1479. break
  1480. if not found_device:
  1481. raise ValueError(
  1482. f"You are attempting to use Flash Attention {flash_attn_version} with a model not initialized on GPU and with no GPU available. "
  1483. "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
  1484. "or initialising the model on CPU and then moving it to GPU."
  1485. )
  1486. # If no error raise by this point, we can return `True`
  1487. return True
  1488. def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool:
  1489. """
  1490. Check the availability of SDPA for a given model.
  1491. Args:
  1492. is_init_check (`bool`, *optional*):
  1493. Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
  1494. fully instantiated. This is needed as we also check the devices of the weights, which are only available
  1495. later after __init__. This allows to raise proper exceptions early before instantiating the full models
  1496. if we know that the model does not support the requested attention.
  1497. """
  1498. if not self._supports_sdpa:
  1499. raise ValueError(
  1500. f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
  1501. " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
  1502. ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
  1503. )
  1504. if (
  1505. torch.version.hip is not None
  1506. and torch.cuda.device_count() > 1
  1507. and version.parse(torch.__version__) < version.parse("2.4.1")
  1508. ):
  1509. logger.warning_once(
  1510. "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
  1511. )
  1512. torch.backends.cuda.enable_flash_sdp(False)
  1513. return True
  1514. def _grouped_mm_can_dispatch(self) -> bool:
  1515. """
  1516. Check the availability of Grouped MM for a given model.
  1517. """
  1518. if not self._can_set_experts_implementation():
  1519. raise ValueError(f"{self.__class__.__name__} does not support setting experts implementation.")
  1520. # If no error raised by this point, we can return `True`
  1521. return True
  1522. def _flex_attn_can_dispatch(self, is_init_check: bool = False) -> bool:
  1523. """
  1524. Check the availability of Flex Attention for a given model.
  1525. Args:
  1526. is_init_check (`bool`, *optional*):
  1527. Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
  1528. fully instantiated. This is needed as we also check the devices of the weights, which are only available
  1529. later after __init__. This allows to raise proper exceptions early before instantiating the full models
  1530. if we know that the model does not support the requested attention.
  1531. """
  1532. if not self._supports_flex_attn:
  1533. raise ValueError(
  1534. f"{self.__class__.__name__} does not support an attention implementation through torch's flex_attention."
  1535. " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809."
  1536. " If you believe this error is a bug, please open an issue in Transformers GitHub repository"
  1537. ' and load your model with the argument `attn_implementation="eager"` meanwhile.'
  1538. ' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
  1539. )
  1540. if not is_torch_flex_attn_available():
  1541. raise ImportError(
  1542. "PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
  1543. )
  1544. # If no error raise by this point, we can return `True`
  1545. return True
  def _check_and_adjust_attn_implementation(
      self, attn_implementation: str | None, is_init_check: bool = False, allow_all_kernels: bool = False
  ) -> str:
      """
      Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if
      it matches hf kernels pattern.

      The resolution happens in three steps:
          1. If a flash attention flavour is requested that is not listed in the model's optional
             `_compatible_flash_implementations`, the first compatible entry is used instead (with a warning).
          2. If an original flash attention package is requested but its availability check fails, the matching
             entry of `FLASH_ATTN_KERNEL_FALLBACK` is selected, provided the model supports flash attention, the
             `kernels` library is available and we are not on NPU.
          3. Kernel implementations are pre-loaded; anything else goes through `get_correct_attn_implementation`.

      Args:
          attn_implementation (`str` or `None`):
              The attention implementation to check for existence/validity. A `"paged|"` prefix is preserved
              through the fallbacks above.
          is_init_check (`bool`, *optional*):
              Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
              fully instantiated. This is needed as we also check the devices of the weights, which are only available
              later after __init__. This allows to raise proper exceptions early before instantiating the full models
              if we know that the model does not support the requested attention.
          allow_all_kernels (`bool`, optional):
              Whether to load kernels from unverified hub repos, if `attn_implementation` is a custom kernel outside
              of the `kernels-community` hub repository.

      Returns:
          `str`: The final attention implementation to use, including potential fallbacks from sdpa to eager, or from
          None to sdpa (to potentially eager).

      Raises:
          Exception: Any error raised while loading a kernel is re-raised. When an original flash attention was
              requested, `_flash_attn_can_dispatch` is called first so that it can raise its own error instead.
      """
      # Auto-correct model's default flash implementation if specified
      if attn_implementation is not None:
          is_paged = attn_implementation.startswith("paged|")
          base_implementation = attn_implementation.removeprefix("paged|")
          compatible_flash_implementations = getattr(self, "_compatible_flash_implementations", None)
          if (
              is_flash_attention_requested(requested_attention_implementation=base_implementation)
              and compatible_flash_implementations is not None
              and base_implementation not in compatible_flash_implementations
          ):
              # Keep the "paged|" prefix of the request on the replacement implementation
              default_flash_implementation = (
                  f"paged|{compatible_flash_implementations[0]}" if is_paged else compatible_flash_implementations[0]
              )
              logger.warning_once(
                  f"This model is compatible with the following flash attention implementations: `{compatible_flash_implementations}`. "
                  f"Automatically falling back to `{default_flash_implementation}` instead of `{attn_implementation}`."
              )
              attn_implementation = default_flash_implementation

      applicable_attn_implementation = attn_implementation

      # Recomputed here because `attn_implementation` may be None (first block skipped) or may have been replaced
      is_paged = attn_implementation is not None and attn_implementation.startswith("paged|")
      requested_original_flash_attn = False
      if is_flash_attention_requested(requested_attention_implementation=attn_implementation):
          # If FA not installed, do not fail but use kernels instead if possible
          for fa_version in FLASH_ATTENTION_COMPATIBILITY_MATRIX.keys():
              # Check whether we have an original FA requested but not available in the env
              # (the walrus keeps the flag at its last evaluated value, i.e. False if no version matched)
              if requested_original_flash_attn := (
                  attn_implementation.removeprefix("paged|") == f"flash_attention_{fa_version}"
                  and not FLASH_ATTENTION_COMPATIBILITY_MATRIX[fa_version]["general_availability_check"]()
              ):
                  break

          if (
              self._supports_flash_attn
              and requested_original_flash_attn
              and is_kernels_available()
              and not is_torch_npu_available()
          ):
              # Swap the unavailable flash attention package for its registered fallback
              applicable_attn_implementation = FLASH_ATTN_KERNEL_FALLBACK[attn_implementation.removeprefix("paged|")]

              if is_torch_xpu_available() and attn_implementation.removeprefix("paged|") == "flash_attention_2":
                  # On XPU, kernels library is the native implementation
                  # Disabling this flag to avoid giving wrong fallbacks on errors and warnings
                  requested_original_flash_attn = False

              # The fallback name was looked up without the prefix, so restore it for paged requests
              if is_paged:
                  applicable_attn_implementation = f"paged|{applicable_attn_implementation}"

      if is_kernel(applicable_attn_implementation):
          try:
              # preload flash attention here to allow compile with fullgraph
              if is_paged:
                  lazy_import_paged_flash_attention(
                      applicable_attn_implementation, allow_all_kernels=allow_all_kernels
                  )
              else:
                  lazy_import_flash_attention(applicable_attn_implementation, allow_all_kernels=allow_all_kernels)

              # log that we used kernel fallback if successful
              if requested_original_flash_attn:
                  logger.warning_once(
                      f"You do not have `flash_attn` installed, using `{applicable_attn_implementation}` "
                      "from the `kernels` library instead!"
                  )
          except Exception as e:
              # raise the proper exception for requested flash attention
              if requested_original_flash_attn:
                  # The flag is only True when the request ends with "flash_attention_<version>"
                  fa_version = int(attn_implementation[-1])  # "flash_attention_(2|3|...)"
                  self._flash_attn_can_dispatch(flash_attn_version=fa_version, is_init_check=is_init_check)

              # error properly out if a kernel was specifically requested
              raise e
      else:
          applicable_attn_implementation = self.get_correct_attn_implementation(
              applicable_attn_implementation, is_init_check
          )

          # preload flash attention here to allow compile with fullgraph
          if is_flash_attention_requested(requested_attention_implementation=applicable_attn_implementation):
              lazy_import_flash_attention(applicable_attn_implementation)

      return applicable_attn_implementation
  1640. def _check_and_adjust_experts_implementation(self, experts_implementation: str | None) -> str:
  1641. """
  1642. Check that the `experts_implementation` exists and is supported by the models.
  1643. Args:
  1644. experts_implementation (`str` or `None`):
  1645. The experts implementation to check for existence/validity.
  1646. Returns:
  1647. `str`: The final experts implementation to use.
  1648. """
  1649. applicable_experts_implementation = self.get_correct_experts_implementation(experts_implementation)
  1650. return applicable_experts_implementation
  def get_correct_attn_implementation(self, requested_attention: str | None, is_init_check: bool = False) -> str:
      """
      Resolve the attention implementation to use for this model and run the matching availability check.

      Args:
          requested_attention (`str` or `None`):
              The requested attention implementation. `None` means "use the default", i.e. `"sdpa"` with a silent
              fallback to `"eager"` when SDPA cannot be dispatched.
          is_init_check (`bool`, *optional*, defaults to `False`):
              Whether this check is performed early, i.e. at __init__ time. Forwarded to the dispatch checks.

      Returns:
          `str`: The attention implementation that should be used.

      Raises:
          ValueError: If the requested implementation is neither `"eager"` nor a key registered in
              `ALL_ATTENTION_FUNCTIONS`. Errors raised by the individual dispatch checks are propagated, except for
              the default (not explicitly requested) SDPA, which falls back to `"eager"`.
      """
      applicable_attention = "sdpa" if requested_attention is None else requested_attention
      if applicable_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
          # Collect the accepted values first, then render all of them in the same copy-pastable
          # `attn_implementation="..."` form
          supported = ["eager", "paged|eager"]
          # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
          if self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False):
              for fa_version in FLASH_ATTENTION_COMPATIBILITY_MATRIX.keys():
                  supported += [f"flash_attention_{fa_version}", f"paged|flash_attention_{fa_version}"]
          if self._supports_sdpa:
              supported += ["sdpa", "paged|sdpa"]
          if self._supports_flex_attn:
              supported.append("flex_attention")
          options = ", ".join(f'`attn_implementation="{name}"`' for name in supported)
          raise ValueError(
              f'Specified `attn_implementation="{applicable_attention}"` is not supported. The only possible arguments are '
              f"{options}."
          )

      # Perform relevant checks
      if is_flash_attention_requested(requested_attention_implementation=applicable_attention) and (
          fa_matched := re.search(r"^flash_attention_(\d)$", applicable_attention)
      ):
          fa_version = int(fa_matched.group(1))  # last digit
          self._flash_attn_can_dispatch(flash_attn_version=fa_version, is_init_check=is_init_check)
      elif "flex_attention" in applicable_attention:
          self._flex_attn_can_dispatch(is_init_check)
      elif "sdpa" in applicable_attention:
          # Sdpa is the default, so we try it and fallback to eager otherwise when not possible
          try:
              self._sdpa_can_dispatch(is_init_check)
          except (ValueError, ImportError):
              # Only an explicit sdpa request is allowed to fail loudly
              if requested_attention is not None and "sdpa" in requested_attention:
                  raise
              applicable_attention = "eager"

      return applicable_attention
  1686. def get_correct_experts_implementation(self, requested_experts: str | None) -> str:
  1687. applicable_experts = "grouped_mm" if requested_experts is None else requested_experts
  1688. if applicable_experts not in ["eager", "grouped_mm", "batched_mm", "deepgemm"]:
  1689. message = (
  1690. f'Specified `experts_implementation="{applicable_experts}"` is not supported. The only possible arguments are '
  1691. '`experts_implementation="eager"`, `"experts_implementation=grouped_mm"`, `"experts_implementation=batched_mm"` '
  1692. 'and `"experts_implementation=deepgemm"`.'
  1693. )
  1694. raise ValueError(message)
  1695. # Perform relevant checks
  1696. if applicable_experts == "grouped_mm":
  1697. try:
  1698. self._grouped_mm_can_dispatch()
  1699. except (ValueError, ImportError) as e:
  1700. if requested_experts == "grouped_mm":
  1701. raise e
  1702. applicable_experts = "eager"
  1703. return applicable_experts
  1704. @classmethod
  1705. def _can_set_attn_implementation(cls) -> bool:
  1706. """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
  1707. opening the file, but avoids maintaining yet another property flag.
  1708. """
  1709. class_module = sys.modules[cls.__module__]
  1710. # This can happen for a custom model in a jupyter notebook or repl for example - simply do not allow to set it then
  1711. if not hasattr(class_module, "__file__"):
  1712. return False
  1713. class_file = class_module.__file__
  1714. with open(class_file, "r", encoding="utf-8") as f:
  1715. code = f.read()
  1716. # heuristic -> if we find those patterns, the model uses the correct interface
  1717. if re.search(r"class \w+Attention\(nn.Module\)", code):
  1718. return "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS.get_interface(" in code
  1719. else:
  1720. # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
  1721. return True
  1722. @classmethod
  1723. def _can_set_experts_implementation(cls) -> bool:
  1724. """Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
  1725. opening the file, but avoids maintaining yet another property flag.
  1726. """
  1727. class_module = sys.modules[cls.__module__]
  1728. # This can happen for a custom model in a jupyter notebook or repl for example - simply do not allow to set it then
  1729. if not hasattr(class_module, "__file__"):
  1730. return False
  1731. class_file = class_module.__file__
  1732. with open(class_file, "r", encoding="utf-8") as f:
  1733. code = f.read()
  1734. # heuristic -> if we the use_experts_implementation decorator is used, then we can set it
  1735. return "@use_experts_implementation" in code
  def set_attn_implementation(self, attn_implementation: str | dict, allow_all_kernels: bool = False):
      """
      Set the requested `attn_implementation` for this model.

      The value is written to the `_attn_implementation_internal` attribute of the model config, of the config of
      every `PreTrainedModel` submodule that uses a different config class, and finally of every remaining
      sub-config for which no submodule was found (without being able to check model support in that case).

      Args:
          attn_implementation (`str` or `dict`):
              The attention implementation to set for this model. It can be either a `str`, in which case it will be
              dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each
              submodel will dispatch the corresponding value. In a `dict`, the key `""` addresses this model itself.
          allow_all_kernels (`bool`, optional):
              Whether to load kernels from unverified hub repos, if `attn_implementation` is a custom kernel outside
              of the `kernels-community` hub repository.

      Raises:
          ValueError: If a sub-config without an associated submodule is given an implementation that is neither
              `"eager"` nor a key registered in `ALL_ATTENTION_FUNCTIONS`. Errors raised by the per-model checks
              are propagated as well.
      """
      requested_implementation = (
          attn_implementation
          if not isinstance(attn_implementation, dict)
          else attn_implementation.get("", self.config._attn_implementation)
      )

      if requested_implementation != self.config._attn_implementation:
          # In this case, warn and leave this model's own config untouched
          if not self._can_set_attn_implementation():
              logger.warning(
                  f"{self.__class__.__name__} does not support setting its attention implementation dynamically, because it "
                  "does not follow the functional approach based on AttentionInterface "
                  "(see https://huggingface.co/docs/transformers/en/attention_interface)"
              )
          else:
              requested_implementation = self._check_and_adjust_attn_implementation(
                  requested_implementation, is_init_check=False, allow_all_kernels=allow_all_kernels
              )
              # Apply the change (on the internal attr, to avoid setting it recursively)
              self.config._attn_implementation_internal = requested_implementation

      # Apply it to all submodels as well
      for submodule in self.modules():
          # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
          # e.g. ForCausalLM has a Model inside, but no need to check it again)
          if (
              submodule is not self
              and isinstance(submodule, PreTrainedModel)
              and submodule.config.__class__ != self.config.__class__
              # If it was already changed, no need to do it again
              and not hasattr(submodule.config, "_attn_was_changed")
          ):
              # In this case, warn and skip
              if not submodule._can_set_attn_implementation():
                  logger.warning(
                      f"{submodule.__class__.__name__} does not support setting its attention implementation dynamically, because it "
                      "does not follow the functional approach based on AttentionInterface "
                      "(see https://huggingface.co/docs/transformers/en/attention_interface)"
                  )
              # Set the attn on the submodule
              else:
                  sub_implementation = requested_implementation
                  if isinstance(attn_implementation, dict):
                      for subconfig_key in self.config.sub_configs:
                          # We need to check for exact object match here, with `is`
                          if getattr(self.config, subconfig_key) is submodule.config:
                              sub_implementation = attn_implementation.get(
                                  subconfig_key, submodule.config._attn_implementation
                              )
                              break
                  # Check the module can use correctly, otherwise we raise an error if requested attention can't be set for submodule
                  sub_implementation = submodule.get_correct_attn_implementation(sub_implementation)
                  submodule.config._attn_implementation_internal = sub_implementation

              # Still add it as "changed" even if it was skipped, as we would otherwise try to set it in the dark afterwards
              # We need to set it on the config itself, to differentiate 2 subconfigs of the same __class__ potentially
              submodule.config._attn_was_changed = True

      # We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel
      for subconfig_key in self.config.sub_configs:
          if (subconfig := getattr(self.config, subconfig_key)) is not None:
              sub_implementation = (
                  requested_implementation
                  if not isinstance(attn_implementation, dict)
                  else attn_implementation.get(subconfig_key, subconfig._attn_implementation)
              )
              # This means we did not perform any check above for this particular subconfig -> set it in the dark if it is registered
              if (
                  not hasattr(subconfig, "_attn_was_changed")
                  # If it's already the same, then no need to enter here and raise warnings
                  and sub_implementation != subconfig._attn_implementation
              ):
                  if sub_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
                      raise ValueError(
                          f'Specified `attn_implementation="{sub_implementation}"` is not supported for {subconfig_key}. '
                          'The only possible arguments are "eager" (manual attention implementation) '
                          f"or one of the following: {list(ALL_ATTENTION_FUNCTIONS.valid_keys())}"
                      )
                  subconfig._attn_implementation_internal = sub_implementation
                  logger.warning(
                      f"We set the attention implementation for the sub-config `{subconfig_key}` to `{sub_implementation}` "
                      "without finding the associated sub-model. For this reason we could not check if the model supports it. "
                      "You may encounter undefined behavior."
                  )
              # Unset the temporary marker in this case, to avoid issues in the future
              elif hasattr(subconfig, "_attn_was_changed"):
                  del subconfig._attn_was_changed
  1832. def set_experts_implementation(self, experts_implementation: str | dict):
  1833. """
  1834. Set the requested `experts_implementation` for this model.
  1835. Args:
  1836. experts_implementation (`str` or `dict`):
  1837. The experts implementation to set for this model. It can be either a `str`, in which case it will be
  1838. dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each
  1839. submodel will dispatch the corresponding value.
  1840. """
  1841. requested_implementation = (
  1842. experts_implementation
  1843. if not isinstance(experts_implementation, dict)
  1844. else experts_implementation.get("", self.config._experts_implementation)
  1845. )
  1846. if requested_implementation != self.config._experts_implementation:
  1847. requested_implementation = self._check_and_adjust_experts_implementation(requested_implementation)
  1848. # Apply the change (on the internal attr, to avoid setting it recursively)
  1849. self.config._experts_implementation_internal = requested_implementation
  1850. # Apply it to all submodels as well
  1851. for submodule in self.modules():
  1852. # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
  1853. # e.g. ForCausalLM has a Model inside, but no need to check it again)
  1854. if (
  1855. submodule is not self
  1856. and isinstance(submodule, PreTrainedModel)
  1857. and submodule.config.__class__ != self.config.__class__
  1858. ):
  1859. # Set the experts on the submodule
  1860. sub_implementation = requested_implementation
  1861. if isinstance(experts_implementation, dict):
  1862. for subconfig_key in self.config.sub_configs:
  1863. # We need to check for exact object match here, with `is`
  1864. if getattr(self.config, subconfig_key) is submodule.config:
  1865. sub_implementation = experts_implementation.get(
  1866. subconfig_key, submodule.config._experts_implementation
  1867. )
  1868. break
  1869. # Check the module can use correctly, otherwise we raise an error if requested experts can't be set for submodule
  1870. sub_implementation = submodule.get_correct_experts_implementation(sub_implementation)
  1871. submodule.config._experts_implementation_internal = sub_implementation
  def enable_input_require_grads(self):
      """
      Make the outputs of the input embeddings require gradients. This is useful for fine-tuning adapter weights
      while keeping the model weights fixed.

      A forward hook is registered once on each distinct input embedding module exposed by the `PreTrainedModel`
      modules found in `self.modules()`. The hook handles are stored in `self._require_grads_hooks`, and the first
      one also in `self._require_grads_hook` for backward compatibility. A warning is logged when no embeddings
      could be found.
      """

      def make_inputs_require_grads(module, input, output):
          output.requires_grad_(True)

      registered_hooks = []
      visited_embedding_ids = set()

      for candidate in self.modules():
          if not isinstance(candidate, PreTrainedModel) or not hasattr(candidate, "get_input_embeddings"):
              continue

          try:
              embeddings = candidate.get_input_embeddings()
          except NotImplementedError:
              continue

          if embeddings is None or not hasattr(embeddings, "register_forward_hook"):
              continue

          # The same embedding module can be exposed by several (nested) models: hook it only once
          if id(embeddings) in visited_embedding_ids:
              continue
          visited_embedding_ids.add(id(embeddings))

          registered_hooks.append(embeddings.register_forward_hook(make_inputs_require_grads))

      self._require_grads_hooks = registered_hooks
      if registered_hooks:
          # for BC
          self._require_grads_hook = registered_hooks[0]
      else:
          logger.warning_once(
              f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
              "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
              "support those features, or set `_input_embed_layer` to the attribute name that holds the embeddings."
          )
  def disable_input_require_grads(self):
      """
      Remove every hook stored in `_require_grads_hooks`, reset that list, and delete the legacy
      `_require_grads_hook` attribute. Does nothing when no hooks are stored.
      """
      registered_hooks = getattr(self, "_require_grads_hooks", None)
      if registered_hooks:
          for handle in registered_hooks:
              handle.remove()

          self._require_grads_hooks = []
          # The single-hook attribute only exists for backward compatibility
          if hasattr(self, "_require_grads_hook"):
              del self._require_grads_hook
  1919. def get_encoder(self, modality: str | None = None):
  1920. """
  1921. Best-effort lookup of the *encoder* module. If provided with `modality` argument,
  1922. it looks for a modality-specific encoder in multimodal models (e.g. "image_encoder")
  1923. By default the function returns model's text encoder if any, and otherwise returns `self`.
  1924. Possible `modality` values are "image", "video" and "audio".
  1925. """
  1926. # NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely
  1927. if modality in ["image", "video"]:
  1928. possible_module_names = ["vision_tower", "visual", "vision_model", "vision_encoder", "image_tower"]
  1929. elif modality == "audio":
  1930. possible_module_names = ["audio_tower", "audio_encoder", "speech_encoder"]
  1931. elif modality is None:
  1932. possible_module_names = ["text_encoder", "encoder"]
  1933. else:
  1934. raise ValueError(f'Unnrecognized modality, has to be "image", "video" or "audio" but found {modality}')
  1935. for name in possible_module_names:
  1936. if hasattr(self, name):
  1937. return getattr(self, name)
  1938. if self.base_model is not self and hasattr(self.base_model, "get_encoder"):
  1939. base_encoder = self.base_model.get_encoder(modality=modality)
  1940. # Base model will always have attr `get_encoder` if inherited from `PreTrainedModel`
  1941. # But it doesn't mean that the model has an encoder module, and we need to return `self`
  1942. if base_encoder != self.base_model:
  1943. return base_encoder
  1944. # If this is a base transformer model (no encoder/model attributes), return self
  1945. return self
  1946. def set_encoder(self, encoder, modality: str | None = None):
  1947. """
  1948. Symmetric setter. Mirrors the lookup logic used in `get_encoder`.
  1949. """
  1950. # NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely
  1951. if modality in ["image", "video"]:
  1952. possible_module_names = ["vision_tower", "visual", "vision_model", "vision_encoder", "image_tower"]
  1953. elif modality == "audio":
  1954. possible_module_names = ["audio_tower", "audio_encoder"]
  1955. elif modality is None:
  1956. possible_module_names = ["text_encoder", "encoder"]
  1957. else:
  1958. raise ValueError(f'Unnrecognized modality, has to be "image", "video" or "audio" but found {modality}')
  1959. for name in possible_module_names:
  1960. if hasattr(self, name):
  1961. setattr(self, name, encoder)
  1962. return
  1963. if self.base_model is not self:
  1964. if hasattr(self.base_model, "set_encoder"):
  1965. self.base_model.set_encoder(encoder, modality=modality)
  1966. else:
  1967. self.model = encoder
  def get_decoder(self):
      """
      Best-effort lookup of the *decoder* module.

      Order of attempts:

      1. the first existing attribute among `self.language_model`, `self.text_model`, `self.decoder` and
         `self.text_decoder`
      2. `self.base_model.get_decoder()` (nested wrappers), when the base model is another object
      3. fallback: `self`, for base transformer models that are themselves the decoder
      """
      for attribute in ("language_model", "text_model", "decoder", "text_decoder"):
          if hasattr(self, attribute):
              return getattr(self, attribute)

      delegate_to_base = self.base_model is not self and hasattr(self.base_model, "get_decoder")
      if delegate_to_base:
          return self.base_model.get_decoder()

      # No decoder/model attributes, e.g. MistralModel which is itself the decoder
      return self
  def set_decoder(self, decoder):
      """
      Symmetric setter. Mirrors the lookup logic used in `get_decoder`.

      The first well-known attribute name that exists on `self` is overwritten with `decoder`. If none exists and
      `self.base_model` is another object, the call is delegated to `self.base_model.set_decoder`, or `decoder` is
      stored in `self.model` when the base model has no such method.

      Args:
          decoder:
              The module to install as decoder.
      """
      # These names have to stay identical to the ones looked up in `get_decoder`
      possible_module_names = ["language_model", "text_model", "decoder", "text_decoder"]
      for name in possible_module_names:
          if hasattr(self, name):
              setattr(self, name, decoder)
              return

      if self.base_model is not self:
          if hasattr(self.base_model, "set_decoder"):
              self.base_model.set_decoder(decoder)
          else:
              self.model = decoder
  @torch.no_grad()
  def _init_weights(self, module):
      """
      Initialize the weights. This is quite general on purpose, in the spirit of what we usually do. For more complex
      initialization scheme, it should be overridden by the derived `PreTrainedModel` class. In case a model adds an explicit
      `nn.Parameter`, this method should also be overridden in order to initialize it correctly.

      Handled module types:
          - linear and (transposed) convolution layers: normal weights, zero bias
          - `nn.Embedding`: normal weights, zeroed padding row
          - `nn.MultiheadAttention`: torch's own `_reset_parameters`
          - normalization layers (torch group/batch norms, or any class whose name contains `LayerNorm` or
            `RMSNorm`): unit weights, zero bias, reset running statistics
          - rotary embeddings (class name contains `RotaryEmbedding` and the module has `original_inv_freq`):
            recomputed inverse frequencies

      Any other module is left untouched.

      Args:
          module:
              The single module whose own parameters and buffers should be initialized.
      """
      # Standard deviation of the normal init, taken from the first of these config attributes that exists
      if hasattr(self.config, "initializer_range"):
          # A falsy value (e.g. `None`) falls back to the library default
          std = self.config.initializer_range or 0.02
      elif hasattr(self.config, "init_std"):
          std = self.config.init_std
      elif hasattr(self.config, "initializer_factor"):
          std = self.config.initializer_factor
      else:
          # 0.02 is the standard default value across the library
          std = getattr(self.config.get_text_config(), "initializer_range", 0.02)

      if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
          if getattr(module, "weight", None) is not None:
              init.normal_(module.weight, mean=0.0, std=std)
          if module.bias is not None:
              init.zeros_(module.bias)
      elif isinstance(module, nn.Embedding):
          init.normal_(module.weight, mean=0.0, std=std)
          # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
          if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
              init.zeros_(module.weight[module.padding_idx])
      elif isinstance(module, nn.MultiheadAttention):
          # This uses torch's original init
          module._reset_parameters()
      # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
      # between modelings (because they are prefixed with the model name)
      elif (
          isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
          or "LayerNorm" in module.__class__.__name__
          or "RMSNorm" in module.__class__.__name__
      ):
          # Norms can exist without weights (in which case they are None from torch primitives)
          if getattr(module, "weight", None) is not None:
              init.ones_(module.weight)
          if getattr(module, "bias", None) is not None:
              init.zeros_(module.bias)
          # And the potential buffers for the BatchNorms
          if getattr(module, "running_mean", None) is not None:
              init.zeros_(module.running_mean)
              init.ones_(module.running_var)
              init.zeros_(module.num_batches_tracked)
      # This matches all the usual RotaryEmbeddings modules
      elif "RotaryEmbedding" in module.__class__.__name__ and hasattr(module, "original_inv_freq"):
          # "default" rope is computed by the module itself, every other type comes from the shared registry
          rope_fn = (
              ROPE_INIT_FUNCTIONS[module.rope_type]
              if module.rope_type != "default"
              else module.compute_default_rope_parameters
          )
          # Only the first returned value is used here
          buffer_value, _ = rope_fn(module.config)
          init.copy_(module.inv_freq, buffer_value)
          init.copy_(module.original_inv_freq, buffer_value)
  2056. def _initialize_weights(self, module, is_remote_code: bool = False):
  2057. """
  2058. Initialize the weights if they are not already initialized.
  2059. """
  2060. if getattr(module, "_is_hf_initialized", False):
  2061. return
  2062. # This check is for remote code that does NOT use either `torch.init` or `transformers.initialization` in `_init_weights`
  2063. # which allow to check the flag directly on param. As they don't and write the params in-place, params would be reinitialized
  2064. # otherwise
  2065. if (
  2066. is_remote_code
  2067. and all(getattr(param, "_is_hf_initialized", False) for param in module.parameters(recurse=False))
  2068. and all(
  2069. getattr(buffer, "_is_hf_initialized", False)
  2070. for buffer in module.buffers(recurse=False)
  2071. if buffer is not None
  2072. )
  2073. ):
  2074. module._is_hf_initialized = True
  2075. return
  2076. self._init_weights(module)
  2077. module._is_hf_initialized = True
  @torch.no_grad()
  @init.guard_torch_init_functions()
  def initialize_weights(self):
      """
      Initialize the whole module tree. This is equivalent to calling `self.apply(self._initialize_weights)`, but
      correctly handles composite models: whenever a sub-model (a `PreTrainedModel`) is met while walking down the
      module graph, its own `_initialize_weights` takes over for its subtree. An arbitrary number of sub-models is
      supported. Without this, every composite model would have to recurse a second time on all its sub-models in
      the outer-most `_init_weights`, which is extremely error prone and inefficient.

      The traversal helper is installed once on `torch.nn.Module` as `smart_apply`.
      """
      if not hasattr(torch.nn.Module, "smart_apply"):

          def smart_apply(self, fn, is_remote_code):
              """Like `torch.nn.Module.apply` (children first, then self), but swapping `fn` at each sub-model."""
              for child in self.children():
                  # A sub-model dispatches its own init function to its subtree, anything else inherits `fn`
                  child_fn = child._initialize_weights if isinstance(child, PreTrainedModel) else fn
                  child.smart_apply(child_fn, is_remote_code)
              fn(self, is_remote_code)
              return self

          torch.nn.Module.smart_apply = smart_apply

      # Walk the whole tree, starting from this model
      self.smart_apply(self._initialize_weights, self.is_remote_code())
  def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict:
      r"""
      Return the expanded tied weight keys (in case they contain modules or regex patterns) for only the current
      model, or recursively for all submodels if `all_submodels=True` (i.e. it will re-check the config values for all
      submodels).

      For almost all models, we only require to tie the embeddings, so the model has an internal property
      `_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}`. In this case, the mapping is already
      "expanded", i.e. it already contains full parameters, and this function will simply return a copy of the property.

      For more complex patterns, e.g. for `DFineForObjectDetection`, we have the following attribute
      ```
      _tied_weights_keys = {
          r"bbox_embed.(?![0])\d+": "bbox_embed.0",
          r"class_embed.(?![0])\d+": "class_embed.0",
          "model.decoder.class_embed": "class_embed",
          "model.decoder.bbox_embed": "bbox_embed",
      }
      ```
      In this case, the function looks up all the model's parameters and buffers, and matches all the params,
      returning the following:
      ```
      {
          'bbox_embed.1.layers.0.bias': 'bbox_embed.0.layers.0.bias',
          'bbox_embed.1.layers.0.weight': 'bbox_embed.0.layers.0.weight',
          'bbox_embed.1.layers.1.bias': 'bbox_embed.0.layers.1.bias',
          'bbox_embed.1.layers.1.weight': 'bbox_embed.0.layers.1.weight',
          'bbox_embed.1.layers.2.bias': 'bbox_embed.0.layers.2.bias',
          'bbox_embed.1.layers.2.weight': 'bbox_embed.0.layers.2.weight',
          'bbox_embed.2.layers.0.bias': 'bbox_embed.0.layers.0.bias',
          'bbox_embed.2.layers.0.weight': 'bbox_embed.0.layers.0.weight',
          ...
          'class_embed.1.bias': 'class_embed.0.bias',
          'class_embed.1.weight': 'class_embed.0.weight',
          'class_embed.2.bias': 'class_embed.0.bias',
          'class_embed.2.weight': 'class_embed.0.weight',
          ...
          'model.decoder.class_embed.0.bias': 'class_embed.0.bias',
          'model.decoder.class_embed.0.weight': 'class_embed.0.weight',
          'model.decoder.class_embed.1.bias': 'class_embed.0.bias',
          'model.decoder.class_embed.1.weight': 'class_embed.0.weight',
          ...
          'model.decoder.bbox_embed.0.layers.0.bias': 'bbox_embed.0.layers.0.bias',
          'model.decoder.bbox_embed.0.layers.0.weight': 'bbox_embed.0.layers.0.weight',
          'model.decoder.bbox_embed.0.layers.1.bias': 'bbox_embed.0.layers.1.bias',
          'model.decoder.bbox_embed.0.layers.1.weight': 'bbox_embed.0.layers.1.weight',
          ...
      }
      ```
      i.e. all the parameters matching the regex and modules patterns in `_tied_weights_keys`

      Args:
          all_submodels (`bool`, defaults to `False`):
              Whether to collect the mapping of every `PreTrainedModel` found in `self.named_modules()` (each key
              and value being prefixed with the submodel's name) instead of only the current model's mapping.

      Returns:
          `dict`: A new `{target: source}` dict of fully expanded names. It is empty if the config does not ask
          for tied word embeddings or if `_tied_weights_keys` is `None`.

      Raises:
          `ValueError`: If a pattern matches no source, no target, or a number of targets that is not a multiple
          of the number of sources.
      """
      if all_submodels:
          expanded_tied_weights = {}
          for prefix, submodule in self.named_modules(remove_duplicate=False):
              if isinstance(submodule, PreTrainedModel):
                  # Will dynamically check the config if it has changed
                  submodel_tied_weights = submodule.get_expanded_tied_weights_keys(all_submodels=False)
                  if prefix != "":
                      submodel_tied_weights = {
                          f"{prefix}.{k}": f"{prefix}.{v}" for k, v in submodel_tied_weights.items()
                      }
                  expanded_tied_weights.update(submodel_tied_weights)
          return expanded_tied_weights

      tied_mapping = self._tied_weights_keys
      # If the config does not specify any tying, return empty dict
      # NOTE: not all modules have `tie_word_embeddings` attr, for example vision-only
      # modules do not have any word embeddings!
      tie_word_embeddings = getattr(self.config, "tie_word_embeddings", False)
      if not tie_word_embeddings:
          return {}
      # If None, return empty dict
      elif tied_mapping is None:
          return {}

      # Short-cut for the most common cases: if the tied weights mapping only contains already expanded params,
      # return it directly. A name qualifies only if it is made of letters, numbers, dots and underscores (so it
      # cannot be a regex pattern) AND ends with "weight" or "bias" (so it is not a module).
      # NOTE: the alternation must be grouped and the whole name must match. Ungrouped, `...(weight)|(bias)$` is
      # read as "starts with ...weight" OR "ends with bias", which lets patterns such as `head.weight_\d+` through.
      common_case_regex = re.compile(r"[A-Za-z0-9_\.]*(?:weight|bias)")
      if all(common_case_regex.fullmatch(k) for k in tied_mapping.keys() | tied_mapping.values()):
          return tied_mapping.copy()

      # We need to expand the regex patterns or the modules into proper parameters
      expanded_tied_weights = {}
      all_param_names = {k for k, _ in self.named_parameters(remove_duplicate=False)} | {
          k for k, _ in self.named_buffers(remove_duplicate=False)
      }
      for target_name, source_name in tied_mapping.items():
          # Anchor at the start so that a pattern only matches names beginning with it
          target_name = "^" + target_name
          source_name = "^" + source_name

          source_params = sorted(name for name in all_param_names if re.search(source_name, name))
          target_params = sorted(name for name in all_param_names if re.search(target_name, name))
          if not source_params or not target_params or len(target_params) % len(source_params) != 0:
              raise ValueError(
                  f"There is an issue with your definition of `tie_weights_keys` for {source_name}:{target_name}. "
                  f"We found {source_params} to tie into {target_params}"
              )

          # we cycle source as it should be dispatch in many target if regex
          for target_n, source_n in zip(target_params, cycle(source_params)):
              # If the source is already registered as a target, use the original corresponding source. This should
              # never happen in general, but some models such as `d_fine` have complicated regex patterns, so it
              # ends up being the case for simplicity of the regexes. Fix it silently here, so that no key is both
              # a source and a target.
              expanded_tied_weights[target_n] = expanded_tied_weights.get(source_n, source_n)

      return expanded_tied_weights
  2211. def tie_weights(self, missing_keys: set[str] | None = None, recompute_mapping: bool = True):
  2212. """
  2213. Tie the model weights. If `recompute_mapping=False` (default when called internally), it will rely on the
  2214. `model.all_tied_weights_keys` attribute, containing the `{target: source}` mapping for the tied params.
  2215. If `recompute_mapping=True`, it will re-check all internal submodels and their config to determine the params
  2216. that need to be tied. This is the default when `model.tie_weights()` is called on its own, outside of
  2217. `__init__`, and `from_pretrained`, in case the config values were changed somewhere.
  2218. Note that during `from_pretrained`, tying is *symmetric*: if the mapping says "tie target -> source" but
  2219. `source` is missing in the checkpoint while `target` exists, we *swap* source and target so we can still
  2220. tie everything to the parameter that actually exists.
  2221. """
  2222. # In this case, the keys stored in `all_tied_weights_keys` are already correct
  2223. if not recompute_mapping:
  2224. tied_keys = self.all_tied_weights_keys
  2225. else:
  2226. tied_keys = self.get_expanded_tied_weights_keys(all_submodels=True)
  2227. tied_keys = list(tied_keys.items())
  2228. for i, (target_param_name, source_param_name) in enumerate(tied_keys):
  2229. # This is `from_pretrained` -> let's check symmetrically in case the source key is not present
  2230. if missing_keys is not None:
  2231. remove_from_missing = True
  2232. source_is_there = source_param_name not in missing_keys
  2233. target_is_there = target_param_name not in missing_keys
  2234. # Both are already present -> it means the config is wrong and do not reflect the actual
  2235. # checkpoint -> let's raise a warning and NOT tie them
  2236. if source_is_there and target_is_there:
  2237. # If both are present, check if the weights are exactly similar, and only tie in this case
  2238. # This check is important, as torch `.bin` checkpoints always contain both keys, referencing the same storage
  2239. if not torch.equal(self.get_parameter(source_param_name), self.get_parameter(target_param_name)):
  2240. logger.warning(
  2241. f"The tied weights mapping and config for this model specifies to tie {source_param_name} to "
  2242. f"{target_param_name}, but both are present in the checkpoints with different values, so we will NOT "
  2243. "tie them. You should update the config with `tie_word_embeddings=False` to silence this warning."
  2244. )
  2245. # Remove from internal attribute to correctly reflect actual tied weights
  2246. self.all_tied_weights_keys.pop(target_param_name)
  2247. # Skip to next iteration
  2248. continue
  2249. # We're missing the source but we have the target -> we swap them, tying the parameter that exists
  2250. elif not source_is_there and target_is_there:
  2251. target_param_name, source_param_name = source_param_name, target_param_name
  2252. # Both are missing -> check other keys in case more than 2 keys are tied to the same weight
  2253. elif not source_is_there and not target_is_there:
  2254. for target_backup, source_backup in tied_keys[i + 1 :]:
  2255. # In case of more than 2 keys tied to the same weight, they are guaranteed to all have
  2256. # the same source thanks to `get_expanded_tied_weights_keys` so this check is enough
  2257. if source_backup == source_param_name:
  2258. target_backup_is_there = target_backup not in missing_keys
  2259. # If the target is present, we found the correct weight to tie into (we know the source is missing)
  2260. # Note here that we do not tie the missing source right now as well, as it will be done anyway when
  2261. # the pair (target_backup, source_backup) becomes the main pair (target_param_name, source_param_name)
  2262. if target_backup_is_there:
  2263. source_param_name = target_backup
  2264. break
  2265. # If we did not break from the loop, it was impossible to find a source key -> let's raise
  2266. else:
  2267. # TODO Cyril: here ideally we want to raise instead of warning, but will break our CI as we have
  2268. # tests loading model from empty dicts to perform init checks - since we don't raise, add a flag
  2269. # to NOT remove from missing keys as it's actually still missing
  2270. remove_from_missing = False
  2271. logger.warning(
  2272. f"This checkpoint seem corrupted. The tied weights mapping for this model specifies to tie "
  2273. f"{source_param_name} to {target_param_name}, but both are absent from the checkpoint, "
  2274. "and we could not find another related tied weight for those keys"
  2275. )
  2276. # Perform the actual tying
  2277. source_param = self.get_parameter_or_buffer(source_param_name)
  2278. if "." in target_param_name:
  2279. parent_name, name = target_param_name.rsplit(".", 1)
  2280. parent = self.get_submodule(parent_name)
  2281. else:
  2282. name = target_param_name
  2283. parent = self
  2284. # Tie the weights
  2285. setattr(parent, name, source_param)
  2286. self._adjust_bias(parent, source_param)
  2287. # Remove from missing if necessary
  2288. if missing_keys is not None and remove_from_missing:
  2289. missing_keys.discard(target_param_name)
  2290. def _adjust_bias(self, output_embeddings, input_embeddings):
  2291. if getattr(output_embeddings, "bias", None) is not None and hasattr(output_embeddings, "weight"):
  2292. weight_shape = output_embeddings.weight.shape
  2293. output_embeddings.bias.data = nn.functional.pad(
  2294. output_embeddings.bias.data,
  2295. (0, weight_shape[0] - output_embeddings.bias.shape[0]),
  2296. "constant",
  2297. 0,
  2298. )
  2299. if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
  2300. output_embeddings.out_features = input_embeddings.num_embeddings
  2301. def resize_token_embeddings(
  2302. self,
  2303. new_num_tokens: int | None = None,
  2304. pad_to_multiple_of: int | None = None,
  2305. mean_resizing: bool = True,
  2306. ) -> nn.Embedding:
  2307. """
  2308. Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
  2309. Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
  2310. Arguments:
  2311. new_num_tokens (`int`, *optional*):
  2312. The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
  2313. vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
  2314. returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
  2315. pad_to_multiple_of (`int`, *optional*):
  2316. If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to
  2317. `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
  2318. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
  2319. `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
  2320. details about this, or help on choosing the correct value for resizing, refer to this guide:
  2321. https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
  2322. mean_resizing (`bool`):
  2323. Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
  2324. covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
  2325. Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
  2326. where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the
  2327. old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
  2328. Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  2329. Return:
  2330. `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
  2331. """
  2332. model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  2333. if new_num_tokens is None and pad_to_multiple_of is None:
  2334. return model_embeds
  2335. # Since we are basically reusing the same old embeddings with new weight values, gathering is required
  2336. is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
  2337. if is_deepspeed_zero3_enabled() and not is_quantized:
  2338. import deepspeed
  2339. with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
  2340. vocab_size = model_embeds.weight.shape[0]
  2341. else:
  2342. vocab_size = model_embeds.weight.shape[0]
  2343. # Update base model and current model config.
  2344. self.config.get_text_config().vocab_size = vocab_size
  2345. self.vocab_size = vocab_size
  2346. # Tie weights again if needed
  2347. self.tie_weights()
  2348. return model_embeds
  2349. def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
  2350. old_embeddings = self.get_input_embeddings()
  2351. new_embeddings = self._get_resized_embeddings(
  2352. old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing
  2353. )
  2354. if hasattr(old_embeddings, "_hf_hook"):
  2355. hook = old_embeddings._hf_hook
  2356. add_hook_to_module(new_embeddings, hook)
  2357. old_embeddings_requires_grad = old_embeddings.weight.requires_grad
  2358. new_embeddings.requires_grad_(old_embeddings_requires_grad)
  2359. self.set_input_embeddings(new_embeddings)
  2360. is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
  2361. # Update new_num_tokens with the actual size of new_embeddings
  2362. if pad_to_multiple_of is not None:
  2363. if is_deepspeed_zero3_enabled() and not is_quantized:
  2364. import deepspeed
  2365. with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
  2366. new_num_tokens = new_embeddings.weight.shape[0]
  2367. else:
  2368. new_num_tokens = new_embeddings.weight.shape[0]
  2369. # if word embeddings are not tied, make sure that lm head is resized as well
  2370. if self.get_output_embeddings() is not None:
  2371. old_lm_head = self.get_output_embeddings()
  2372. if isinstance(old_lm_head, torch.nn.Embedding):
  2373. new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
  2374. else:
  2375. new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
  2376. if hasattr(old_lm_head, "_hf_hook"):
  2377. hook = old_lm_head._hf_hook
  2378. add_hook_to_module(new_lm_head, hook)
  2379. old_lm_head_requires_grad = old_lm_head.weight.requires_grad
  2380. new_lm_head.requires_grad_(old_lm_head_requires_grad)
  2381. self.set_output_embeddings(new_lm_head)
  2382. return self.get_input_embeddings()
  2383. def _get_resized_embeddings(
  2384. self,
  2385. old_embeddings: nn.Embedding,
  2386. new_num_tokens: int | None = None,
  2387. pad_to_multiple_of: int | None = None,
  2388. mean_resizing: bool = True,
  2389. ) -> nn.Embedding:
  2390. """
  2391. Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
  2392. initialized vectors at the end. Reducing the size will remove vectors from the end
  2393. Args:
  2394. old_embeddings (`torch.nn.Embedding`):
  2395. Old embeddings to be resized.
  2396. new_num_tokens (`int`, *optional*):
  2397. New number of tokens in the embedding matrix.
  2398. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
  2399. vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
  2400. `torch.nn.Embedding` module of the model without doing anything.
  2401. pad_to_multiple_of (`int`, *optional*):
  2402. If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
  2403. `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
  2404. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
  2405. `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
  2406. details about this, or help on choosing the correct value for resizing, refer to this guide:
  2407. https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
  2408. mean_resizing (`bool`):
  2409. Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
  2410. covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
  2411. Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
  2412. where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the
  2413. old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
  2414. Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  2415. Return:
  2416. `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
  2417. `new_num_tokens` is `None`
  2418. """
  2419. if pad_to_multiple_of is not None:
  2420. if not isinstance(pad_to_multiple_of, int):
  2421. raise ValueError(
  2422. f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer"
  2423. )
  2424. if new_num_tokens is None:
  2425. new_num_tokens = old_embeddings.weight.shape[0]
  2426. new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
  2427. else:
  2428. logger.info(
  2429. "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
  2430. f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
  2431. " For more details about this, or help on choosing the correct value for resizing, refer to this guide:"
  2432. " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
  2433. )
  2434. if new_num_tokens is None:
  2435. return old_embeddings
  2436. is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
  2437. if is_deepspeed_zero3_enabled() and not is_quantized:
  2438. import deepspeed
  2439. with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
  2440. old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
  2441. else:
  2442. old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
  2443. if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
  2444. return old_embeddings
  2445. if not isinstance(old_embeddings, nn.Embedding):
  2446. raise TypeError(
  2447. f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
  2448. " should either use a different resize function or make sure that `old_embeddings` are an instance of"
  2449. f" {nn.Embedding}."
  2450. )
  2451. # Build new embeddings
  2452. # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
  2453. # because the shape of the new embedding layer is used across various modeling files
  2454. # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
  2455. # to errors when training.
  2456. new_embeddings = nn.Embedding(
  2457. new_num_tokens,
  2458. old_embedding_dim,
  2459. device=old_embeddings.weight.device,
  2460. dtype=old_embeddings.weight.dtype,
  2461. )
  2462. if new_num_tokens > old_num_tokens and not mean_resizing:
  2463. # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`.
  2464. self._init_weights(new_embeddings)
  2465. elif new_num_tokens > old_num_tokens and mean_resizing:
  2466. # initialize new embeddings (in particular added tokens). The new embeddings will be initialized
  2467. # from a multivariate normal distribution that has old embeddings' mean and covariance.
  2468. # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  2469. logger.warning_once(
  2470. "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. "
  2471. "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
  2472. "To disable this, use `mean_resizing=False`"
  2473. )
  2474. added_num_tokens = new_num_tokens - old_num_tokens
  2475. if is_deepspeed_zero3_enabled() and not is_quantized:
  2476. import deepspeed
  2477. with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
  2478. self._init_added_embeddings_weights_with_mean(
  2479. old_embeddings, new_embeddings, old_num_tokens, added_num_tokens
  2480. )
  2481. else:
  2482. self._init_added_embeddings_weights_with_mean(
  2483. old_embeddings, new_embeddings, old_num_tokens, added_num_tokens
  2484. )
  2485. # Copy token embeddings from the previous weights
  2486. # numbers of tokens to copy
  2487. n = min(old_num_tokens, new_num_tokens)
  2488. if is_deepspeed_zero3_enabled() and not is_quantized:
  2489. import deepspeed
  2490. params = [old_embeddings.weight, new_embeddings.weight]
  2491. with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
  2492. new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
  2493. else:
  2494. new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
  2495. # Replace weights in old_embeddings and return to maintain the same embedding type.
  2496. # This ensures correct functionality when a Custom Embedding class is passed as input.
  2497. # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
  2498. if is_deepspeed_zero3_enabled() and not is_quantized:
  2499. import deepspeed
  2500. params = [old_embeddings.weight, new_embeddings.weight]
  2501. with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
  2502. old_embeddings.weight = new_embeddings.weight
  2503. old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
  2504. # If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
  2505. # will be set to `None` in the resized embeddings.
  2506. if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
  2507. old_embeddings.padding_idx = None
  2508. else:
  2509. old_embeddings.weight.data = new_embeddings.weight.data
  2510. old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
  2511. if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
  2512. old_embeddings.padding_idx = None
  2513. return old_embeddings
  2514. def _get_resized_lm_head(
  2515. self,
  2516. old_lm_head: nn.Linear,
  2517. new_num_tokens: int | None = None,
  2518. transposed: bool = False,
  2519. mean_resizing: bool = True,
  2520. ) -> nn.Linear:
  2521. """
  2522. Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
  2523. vectors at the end. Reducing the size will remove vectors from the end
  2524. Args:
  2525. old_lm_head (`torch.nn.Linear`):
  2526. Old lm head liner layer to be resized.
  2527. new_num_tokens (`int`, *optional*):
  2528. New number of tokens in the linear matrix.
  2529. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
  2530. vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
  2531. `torch.nn.Linear` module of the model without doing anything. transposed (`bool`, *optional*, defaults
  2532. to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim,
  2533. vocab_size` else `vocab_size, lm_head_dim`.
  2534. mean_resizing (`bool`):
  2535. Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
  2536. covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
  2537. Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
  2538. where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the
  2539. old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
  2540. Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  2541. Return:
  2542. `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
  2543. `None`
  2544. """
  2545. if new_num_tokens is None:
  2546. return old_lm_head
  2547. is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
  2548. if is_deepspeed_zero3_enabled() and not is_quantized:
  2549. import deepspeed
  2550. with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
  2551. old_num_tokens, old_lm_head_dim = (
  2552. old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
  2553. )
  2554. else:
  2555. old_num_tokens, old_lm_head_dim = (
  2556. old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
  2557. )
  2558. if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
  2559. return old_lm_head
  2560. if not isinstance(old_lm_head, nn.Linear):
  2561. raise TypeError(
  2562. f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You"
  2563. " should either use a different resize function or make sure that `old_lm_head` are an instance of"
  2564. f" {nn.Linear}."
  2565. )
  2566. # Build new lm head
  2567. new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
  2568. has_new_lm_head_bias = old_lm_head.bias is not None
  2569. # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
  2570. # because the shape of the new embedding layer is used across various modeling files
  2571. # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
  2572. # to errors when training.
  2573. new_lm_head = nn.Linear(
  2574. *new_lm_head_shape,
  2575. bias=has_new_lm_head_bias,
  2576. device=old_lm_head.weight.device,
  2577. dtype=old_lm_head.weight.dtype,
  2578. )
  2579. if new_num_tokens > old_num_tokens and not mean_resizing:
  2580. # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`.
  2581. self._init_weights(new_lm_head)
  2582. elif new_num_tokens > old_num_tokens and mean_resizing:
  2583. # initialize new lm_head weights (in particular added tokens). The new lm_head weights
  2584. # will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance.
  2585. # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  2586. logger.warning_once(
  2587. "The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. "
  2588. "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
  2589. "To disable this, use `mean_resizing=False`"
  2590. )
  2591. added_num_tokens = new_num_tokens - old_num_tokens
  2592. if is_deepspeed_zero3_enabled() and not is_quantized:
  2593. import deepspeed
  2594. params = [old_lm_head.weight]
  2595. if has_new_lm_head_bias:
  2596. params += [old_lm_head.bias]
  2597. with deepspeed.zero.GatheredParameters(params, modifier_rank=None):
  2598. self._init_added_lm_head_weights_with_mean(
  2599. old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
  2600. )
  2601. if has_new_lm_head_bias:
  2602. self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)
  2603. else:
  2604. self._init_added_lm_head_weights_with_mean(
  2605. old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
  2606. )
  2607. if has_new_lm_head_bias:
  2608. self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)
  2609. num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
  2610. if is_deepspeed_zero3_enabled() and not is_quantized:
  2611. import deepspeed
  2612. params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
  2613. with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
  2614. self._copy_lm_head_original_to_resized(
  2615. new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
  2616. )
  2617. else:
  2618. self._copy_lm_head_original_to_resized(
  2619. new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
  2620. )
  2621. new_lm_head._is_hf_initialized = True
  2622. return new_lm_head
  2623. def _init_added_embeddings_weights_with_mean(
  2624. self, old_embeddings, new_embeddings, old_num_tokens, added_num_tokens
  2625. ):
  2626. old_embeddings_weight = old_embeddings.weight.data.to(torch.float32)
  2627. mean_embeddings = torch.mean(old_embeddings_weight, axis=0)
  2628. old_centered_embeddings = old_embeddings_weight - mean_embeddings
  2629. covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens
  2630. # Check if the covariance is positive definite.
  2631. epsilon = 1e-9
  2632. is_covariance_psd = constraints.positive_definite.check(epsilon * covariance).all()
  2633. if is_covariance_psd:
  2634. # If covariances is positive definite, a distribution can be created. and we can sample new weights from it.
  2635. distribution = torch.distributions.multivariate_normal.MultivariateNormal(
  2636. mean_embeddings, covariance_matrix=epsilon * covariance
  2637. )
  2638. new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample(
  2639. sample_shape=(added_num_tokens,)
  2640. ).to(old_embeddings.weight.dtype)
  2641. else:
  2642. # Otherwise, just initialize with the mean. because distribution will not be created.
  2643. new_embeddings.weight.data[-1 * added_num_tokens :, :] = (
  2644. mean_embeddings[None, :].repeat(added_num_tokens, 1).to(old_embeddings.weight.dtype)
  2645. )
  2646. def _init_added_lm_head_weights_with_mean(
  2647. self,
  2648. old_lm_head,
  2649. new_lm_head,
  2650. old_lm_head_dim,
  2651. old_num_tokens,
  2652. added_num_tokens,
  2653. transposed: bool = False,
  2654. ):
  2655. if transposed:
  2656. # Transpose to the desired shape for the function.
  2657. new_lm_head.weight.data = new_lm_head.weight.data.T
  2658. old_lm_head.weight.data = old_lm_head.weight.data.T
  2659. # The same initialization logic as Embeddings.
  2660. self._init_added_embeddings_weights_with_mean(old_lm_head, new_lm_head, old_num_tokens, added_num_tokens)
  2661. if transposed:
  2662. # Transpose again to the correct shape.
  2663. new_lm_head.weight.data = new_lm_head.weight.data.T
  2664. old_lm_head.weight.data = old_lm_head.weight.data.T
  2665. def _init_added_lm_head_bias_with_mean(self, old_lm_head, new_lm_head, added_num_tokens):
  2666. bias_mean = torch.mean(old_lm_head.bias.data, axis=0, dtype=torch.float32)
  2667. bias_std = torch.std(old_lm_head.bias.data, axis=0).to(torch.float32)
  2668. new_lm_head.bias.data[-1 * added_num_tokens :].normal_(mean=bias_mean, std=1e-9 * bias_std)
  2669. def _copy_lm_head_original_to_resized(
  2670. self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
  2671. ):
  2672. # Copy old lm head weights to new lm head
  2673. if not transposed:
  2674. new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
  2675. else:
  2676. new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
  2677. # Copy bias weights to new lm head
  2678. if has_new_lm_head_bias:
  2679. new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
  2680. def resize_position_embeddings(self, new_num_position_embeddings: int):
  2681. raise NotImplementedError(
  2682. f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
  2683. f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
  2684. )
  2685. def get_position_embeddings(self) -> nn.Embedding | tuple[nn.Embedding]:
  2686. raise NotImplementedError(
  2687. f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
  2688. f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
  2689. )
  2690. def init_weights(self):
  2691. """
  2692. Initialize and tie the weights if needed. If using a custom `PreTrainedModel`, you need to implement any
  2693. initialization logic in `_init_weights`.
  2694. """
  2695. # If we are initializing on meta device, there is no point in trying to run inits
  2696. if get_torch_context_manager_or_global_device() != torch.device("meta"):
  2697. # Initialize weights
  2698. self.initialize_weights()
  2699. # Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
  2700. self.tie_weights(recompute_mapping=False)
  2701. def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
  2702. """
  2703. Activates gradient checkpointing for the current model.
  2704. We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
  2705. the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
  2706. Args:
  2707. gradient_checkpointing_kwargs (dict, *optional*):
  2708. Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
  2709. """
  2710. if not self.supports_gradient_checkpointing:
  2711. raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
  2712. if gradient_checkpointing_kwargs is None:
  2713. gradient_checkpointing_kwargs = {"use_reentrant": False}
  2714. gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
  2715. # For old GC format (transformers < 4.35.0) for models that live on the Hub
  2716. # we will fall back to the overwritten `_set_gradient_checkpointing` method
  2717. _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
  2718. if not _is_using_old_format:
  2719. self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
  2720. else:
  2721. self.apply(partial(self._set_gradient_checkpointing, value=True))
  2722. logger.warning(
  2723. "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
  2724. "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
  2725. )
  2726. needs_embedding_grads = self.main_input_name == "input_ids"
  2727. # we use that also to detect whether or not we have to raise if embeddings are missing (the submodel might not have embeddings at all)
  2728. enable_input_grads = needs_embedding_grads or getattr(self, "_hf_peft_config_loaded", False)
  2729. if enable_input_grads:
  2730. # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
  2731. # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
  2732. # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
  2733. # the gradients to make sure the gradient flows.
  2734. self.enable_input_require_grads()
  2735. def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
  2736. is_gradient_checkpointing_set = False
  2737. # Apply it on the top-level module in case the top-level modules supports it
  2738. # for example, LongT5Stack inherits from `PreTrainedModel`.
  2739. if hasattr(self, "gradient_checkpointing"):
  2740. self._gradient_checkpointing_func = gradient_checkpointing_func
  2741. self.gradient_checkpointing = enable
  2742. is_gradient_checkpointing_set = True
  2743. for module in self.modules():
  2744. if hasattr(module, "gradient_checkpointing"):
  2745. module._gradient_checkpointing_func = gradient_checkpointing_func
  2746. module.gradient_checkpointing = enable
  2747. is_gradient_checkpointing_set = True
  2748. if not is_gradient_checkpointing_set:
  2749. raise ValueError(
  2750. f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
  2751. " `gradient_checkpointing` to modules of the model that uses checkpointing."
  2752. )
  2753. def gradient_checkpointing_disable(self):
  2754. """
  2755. Deactivates gradient checkpointing for the current model.
  2756. """
  2757. if self.supports_gradient_checkpointing:
  2758. # For old GC format (transformers < 4.35.0) for models that live on the Hub
  2759. # we will fall back to the overwritten `_set_gradient_checkpointing` method
  2760. _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
  2761. if not _is_using_old_format:
  2762. self._set_gradient_checkpointing(enable=False)
  2763. else:
  2764. logger.warning(
  2765. "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
  2766. "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
  2767. )
  2768. self.apply(partial(self._set_gradient_checkpointing, value=False))
  2769. if getattr(self, "_hf_peft_config_loaded", False):
  2770. self.disable_input_require_grads()
  2771. @property
  2772. def is_gradient_checkpointing(self) -> bool:
  2773. """
  2774. Whether gradient checkpointing is activated for this model or not.
  2775. """
  2776. return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
  2777. def save_pretrained(
  2778. self,
  2779. save_directory: str | os.PathLike,
  2780. is_main_process: bool = True,
  2781. state_dict: dict | None = None,
  2782. push_to_hub: bool = False,
  2783. max_shard_size: int | str = "50GB",
  2784. variant: str | None = None,
  2785. token: str | bool | None = None,
  2786. save_peft_format: bool = True,
  2787. save_original_format: bool = True,
  2788. **kwargs,
  2789. ):
  2790. """
  2791. Save a model and its configuration file to a directory, so that it can be re-loaded using the
  2792. [`~PreTrainedModel.from_pretrained`] class method.
  2793. Arguments:
  2794. save_directory (`str` or `os.PathLike`):
  2795. Directory to which to save. Will be created if it doesn't exist.
  2796. is_main_process (`bool`, *optional*, defaults to `True`):
  2797. Whether the process calling this is the main process or not. Useful when in distributed training like
  2798. TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
  2799. the main process to avoid race conditions.
  2800. state_dict (nested dictionary of `torch.Tensor`):
  2801. The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
  2802. save parts of the model or if special precautions need to be taken when recovering the state dictionary
  2803. of a model (like when using model parallelism).
  2804. push_to_hub (`bool`, *optional*, defaults to `False`):
  2805. Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
  2806. repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
  2807. namespace).
  2808. max_shard_size (`int` or `str`, *optional*, defaults to `"50GB"`):
  2809. The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
  2810. lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
  2811. <Tip warning={true}>
  2812. If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
  2813. which will be bigger than `max_shard_size`.
  2814. </Tip>
  2815. variant (`str`, *optional*):
  2816. If specified, weights are saved in the format model.<variant>.safetensors.
  2817. token (`str` or `bool`, *optional*):
  2818. The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  2819. the token generated when running `hf auth login` (stored in `~/.huggingface`).
  2820. save_peft_format (`bool`, *optional*, defaults to `True`):
  2821. For backward compatibility with PEFT library, in case adapter weights are attached to the model, all
  2822. keys of the state dict of adapters needs to be prepended with `base_model.model`. Advanced users can
  2823. disable this behaviours by setting `save_peft_format` to `False`.
  2824. save_original_format (`bool`, *optional*, defaults to `True`):
  2825. For backward compatibility with the previous versions of `transformers` you can save the checkpoint with
  2826. its reverse mapping. The reverse mapping needs to exists even if the model was loaded from a None legacy
  2827. checkpoint.
  2828. kwargs (`dict[str, Any]`, *optional*):
  2829. Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
  2830. """
  2831. if token is not None:
  2832. kwargs["token"] = token
  2833. _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False)
  2834. hf_quantizer = getattr(self, "hf_quantizer", None)
  2835. quantization_serializable = (
  2836. hf_quantizer is not None and isinstance(hf_quantizer, HfQuantizer) and hf_quantizer.is_serializable()
  2837. )
  2838. if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
  2839. raise ValueError(
  2840. f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
  2841. " the logger on the traceback to understand the reason why the quantized model is not serializable."
  2842. )
  2843. # we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
  2844. if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
  2845. raise ImportError(
  2846. "Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
  2847. )
  2848. if os.path.isfile(save_directory):
  2849. logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
  2850. return
  2851. os.makedirs(save_directory, exist_ok=True)
  2852. if push_to_hub:
  2853. commit_message = kwargs.pop("commit_message", None)
  2854. repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
  2855. create_pr = kwargs.pop("create_pr", False)
  2856. repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id
  2857. files_timestamps = self._get_files_timestamps(save_directory)
  2858. metadata = {}
  2859. if hf_quantizer is not None:
  2860. state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self)
  2861. metadata["format"] = "pt"
  2862. # Only save the model itself if we are using distributed training
  2863. model_to_save = unwrap_model(self)
  2864. # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
  2865. # we currently don't use this setting automatically, but may start to use with v5
  2866. dtype = model_to_save.dtype
  2867. model_to_save.config.dtype = str(dtype).split(".")[1]
  2868. # Attach architecture to the config
  2869. # When using FSDP2, unwrapping is a noop, so the model name doesn't change back to the original model name
  2870. model_to_save.config.architectures = [model_to_save.__class__.__name__.removeprefix("FSDP")]
  2871. # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
  2872. # loaded from the Hub.
  2873. if self._auto_class is not None:
  2874. custom_object_save(self, save_directory, config=self.config)
  2875. # Save the config
  2876. if is_main_process:
  2877. if not _hf_peft_config_loaded:
  2878. model_to_save.config.save_pretrained(save_directory)
  2879. if self.can_generate():
  2880. model_to_save.generation_config.save_pretrained(save_directory)
  2881. if _hf_peft_config_loaded:
  2882. logger.info(
  2883. "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
  2884. )
  2885. state_dict = model_to_save.get_adapter_state_dict(state_dict=state_dict)
  2886. if save_peft_format:
  2887. logger.info(
  2888. "To match the expected format of the PEFT library, all keys of the state dict of adapters will be prepended with `base_model.model`."
  2889. )
  2890. peft_state_dict = {}
  2891. for key, value in state_dict.items():
  2892. peft_state_dict[f"base_model.model.{key}"] = value
  2893. state_dict = peft_state_dict
  2894. active_adapter = self.active_adapters()
  2895. if len(active_adapter) > 1:
  2896. raise ValueError(
  2897. "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one "
  2898. "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`"
  2899. )
  2900. active_adapter = active_adapter[0]
  2901. current_peft_config = self.peft_config[active_adapter]
  2902. current_peft_config.save_pretrained(save_directory)
  2903. # Get the model state_dict
  2904. if state_dict is None:
  2905. state_dict = model_to_save.state_dict()
  2906. # if any model parameters are offloaded, we need to know it for later
  2907. is_offloaded = False
  2908. if (
  2909. hasattr(self, "hf_device_map")
  2910. and len(set(self.hf_device_map.values())) > 1
  2911. and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
  2912. ):
  2913. is_offloaded = True
  2914. warnings.warn(
  2915. "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory "
  2916. "exceeds the `shard_size` (50GB default)"
  2917. )
  2918. # Translate state_dict from smp to hf if saving with smp >= 1.10
  2919. if IS_SAGEMAKER_MP_POST_1_10:
  2920. for smp_to_hf, _ in smp.state.module_manager.translate_functions:
  2921. state_dict = smp_to_hf(state_dict)
  2922. # Handle the case where some state_dict keys shouldn't be saved
  2923. if self._keys_to_ignore_on_save is not None:
  2924. for ignore_key in self._keys_to_ignore_on_save:
  2925. if ignore_key in state_dict:
  2926. del state_dict[ignore_key]
  2927. # If model was sharded with TP, gather full tensors for saving
  2928. if self._tp_size is not None:
  2929. state_dict = gather_state_dict_for_save(state_dict, self._tp_plan, self._device_mesh, self._tp_size)
  2930. # Remove tied weights as safetensors do not handle them
  2931. state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save)
  2932. # Revert all renaming and/or weight operations
  2933. if save_original_format and not _hf_peft_config_loaded:
  2934. state_dict = revert_weight_conversion(model_to_save, state_dict)
  2935. # Shard the model if it is too big.
  2936. if not _hf_peft_config_loaded:
  2937. weights_name = SAFE_WEIGHTS_NAME
  2938. weights_name = _add_variant(weights_name, variant)
  2939. else:
  2940. weights_name = ADAPTER_SAFE_WEIGHTS_NAME
  2941. filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
  2942. state_dict_split = split_torch_state_dict_into_shards(
  2943. state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
  2944. )
  2945. # Save index if sharded
  2946. index = None
  2947. if state_dict_split.is_sharded:
  2948. index = {
  2949. "metadata": {"total_parameters": self.num_parameters(), **state_dict_split.metadata},
  2950. "weight_map": state_dict_split.tensor_to_filename,
  2951. }
  2952. # Clean the folder from a previous save
  2953. for filename in os.listdir(save_directory):
  2954. full_filename = os.path.join(save_directory, filename)
  2955. # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
  2956. # in distributed settings to avoid race conditions.
  2957. weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
  2958. # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
  2959. filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
  2960. reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
  2961. if (
  2962. filename.startswith(weights_no_suffix)
  2963. and os.path.isfile(full_filename)
  2964. and filename not in state_dict_split.filename_to_tensors
  2965. and is_main_process
  2966. and reg.fullmatch(filename_no_suffix) is not None
  2967. ):
  2968. os.remove(full_filename)
  2969. # Save the model
  2970. for shard_file, tensor_names in logging.tqdm(
  2971. state_dict_split.filename_to_tensors.items(), desc="Writing model shards"
  2972. ):
  2973. filename = os.path.join(save_directory, shard_file)
  2974. shard_state_dict = {}
  2975. for tensor_name in tensor_names:
  2976. # Get the tensor, and remove it from state_dict to avoid keeping the ref
  2977. tensor = state_dict.pop(tensor_name)
  2978. # If the param was offloaded, we need to load it back from disk to resave it. It's a strange pattern,
  2979. # but it would otherwise not be contained in the saved shard if we were to simply move the file
  2980. # or something
  2981. if is_offloaded and tensor.device.type == "meta":
  2982. tensor = load_offloaded_parameter(model_to_save, tensor_name)
  2983. # only do contiguous after it's permuted correctly in case of TP
  2984. shard_state_dict[tensor_name] = tensor.contiguous()
  2985. # TODO: it would be very nice to do the writing concurrently, but safetensors never releases the GIL,
  2986. # so it's not possible for now....
  2987. # Write the shard to disk
  2988. safe_save_file(shard_state_dict, filename, metadata=metadata)
  2989. # Cleanup the data before next loop (important with offloading, so we don't blowup cpu RAM)
  2990. del shard_state_dict
  2991. if index is None:
  2992. path_to_weights = os.path.join(save_directory, weights_name)
  2993. logger.info(f"Model weights saved in {path_to_weights}")
  2994. else:
  2995. save_index_file = SAFE_WEIGHTS_INDEX_NAME
  2996. save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
  2997. # Save the index as well
  2998. with open(save_index_file, "w", encoding="utf-8") as f:
  2999. content = json.dumps(index, indent=2, sort_keys=True) + "\n"
  3000. f.write(content)
  3001. logger.info(
  3002. f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
  3003. f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
  3004. f"index located at {save_index_file}."
  3005. )
  3006. if push_to_hub:
  3007. # Eventually create an empty model card
  3008. model_card = create_and_tag_model_card(repo_id, self.model_tags, token=token)
  3009. # Update model card if needed:
  3010. model_card.save(os.path.join(save_directory, "README.md"))
  3011. self._upload_modified_files(
  3012. save_directory,
  3013. repo_id,
  3014. files_timestamps,
  3015. commit_message=commit_message,
  3016. token=token,
  3017. create_pr=create_pr,
  3018. )
  3019. @wraps(PushToHubMixin.push_to_hub)
  3020. def push_to_hub(self, *args, **kwargs):
  3021. tags = self.model_tags if self.model_tags is not None else []
  3022. tags_kwargs = kwargs.get("tags", [])
  3023. if isinstance(tags_kwargs, str):
  3024. tags_kwargs = [tags_kwargs]
  3025. for tag in tags_kwargs:
  3026. if tag not in tags:
  3027. tags.append(tag)
  3028. if tags:
  3029. kwargs["tags"] = tags
  3030. return super().push_to_hub(*args, **kwargs)
  3031. def get_memory_footprint(self, return_buffers=True):
  3032. r"""
  3033. Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
  3034. Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
  3035. PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
  3036. Arguments:
  3037. return_buffers (`bool`, *optional*, defaults to `True`):
  3038. Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
  3039. are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
  3040. norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
  3041. """
  3042. mem = sum(param.nelement() * param.element_size() for param in self.parameters())
  3043. if return_buffers:
  3044. mem_bufs = sum(buf.nelement() * buf.element_size() for buf in self.buffers())
  3045. mem = mem + mem_bufs
  3046. return mem
  3047. @wraps(torch.nn.Module.cuda)
  3048. def cuda(self, *args, **kwargs):
  3049. if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
  3050. from hqq.core.quantize import HQQLinear
  3051. # Since HQQLinear stores some tensors in the 'meta' attribute,
  3052. # it's necessary to manually call the `cuda` method on HQQLinear layers.
  3053. super().cuda(*args, **kwargs)
  3054. for module in self.modules():
  3055. if isinstance(module, HQQLinear):
  3056. if len(args) > 0:
  3057. device = args[0]
  3058. else:
  3059. device = kwargs.get("device", "cuda")
  3060. module.cuda(device)
  3061. return self
  3062. # Checks if the model has been loaded in 4-bit or 8-bit with BNB
  3063. if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
  3064. if getattr(self, "is_loaded_in_8bit", False):
  3065. raise ValueError(
  3066. "Calling `cuda()` is not supported for `8-bit` quantized models. "
  3067. " Please use the model as it is, since the model has already been set to the correct devices."
  3068. )
  3069. return super().cuda(*args, **kwargs)
  3070. @wraps(torch.nn.Module.to)
  3071. def to(self, *args, **kwargs):
  3072. # For BNB/GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours.
  3073. # the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
  3074. dtype_present_in_args = "dtype" in kwargs
  3075. if not dtype_present_in_args:
  3076. for arg in args:
  3077. if isinstance(arg, torch.dtype):
  3078. dtype_present_in_args = True
  3079. break
  3080. if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
  3081. from hqq.core.quantize import HQQLinear
  3082. # Since HQQLinear stores some tensors in the 'meta' attribute, we must
  3083. # explicitly move the parameters to the target device for each HQQLinear layer after `to`.
  3084. super().to(*args, **kwargs)
  3085. for module in self.modules():
  3086. if isinstance(module, HQQLinear):
  3087. if "device" in kwargs:
  3088. device = kwargs["device"]
  3089. else:
  3090. device = args[0]
  3091. if "dtype" in kwargs:
  3092. dtype = kwargs["dtype"]
  3093. elif dtype_present_in_args:
  3094. dtype = arg
  3095. else:
  3096. dtype = None
  3097. # Due to the current messy implementation of HQQLinear, updating `compute_dtype`
  3098. # followed by calling the `cuda` method achieves the intended behavior of `to`,
  3099. # even when the target device is CPU.
  3100. if dtype is not None:
  3101. module.compute_dtype = dtype
  3102. module.cuda(device)
  3103. return self
  3104. if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
  3105. raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")
  3106. # Checks if the model has been loaded in 4-bit or 8-bit with BNB
  3107. if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
  3108. if dtype_present_in_args:
  3109. raise ValueError(
  3110. "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
  3111. " desired `dtype` by passing the correct `dtype` argument."
  3112. )
  3113. if getattr(self, "is_loaded_in_8bit", False) and not is_bitsandbytes_available("0.48"):
  3114. raise ValueError(
  3115. "You need to install `pip install bitsandbytes>=0.48.0` if you want to move a 8-bit model across devices using to()."
  3116. )
  3117. elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
  3118. if dtype_present_in_args:
  3119. raise ValueError(
  3120. "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
  3121. " `dtype` by passing the correct `dtype` argument."
  3122. )
  3123. return super().to(*args, **kwargs)
  3124. def half(self, *args):
  3125. # Checks if the model is quantized
  3126. if getattr(self, "is_quantized", False):
  3127. raise ValueError(
  3128. "`.half()` is not supported for quantized model. Please use the model as it is, since the"
  3129. " model has already been casted to the correct `dtype`."
  3130. )
  3131. else:
  3132. return super().half(*args)
  3133. def float(self, *args):
  3134. # Checks if the model is quantized
  3135. if getattr(self, "is_quantized", False):
  3136. raise ValueError(
  3137. "`.float()` is not supported for quantized model. Please use the model as it is, since the"
  3138. " model has already been casted to the correct `dtype`."
  3139. )
  3140. else:
  3141. return super().float(*args)
  3142. @classmethod
  3143. def get_init_context(
  3144. cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool, allow_all_kernels: bool | None
  3145. ):
  3146. # Need to instantiate with correct dtype
  3147. init_contexts = [local_torch_dtype(dtype, cls.__name__), init.no_tie_weights(), apply_patches()]
  3148. # Needed as we cannot forward the `allow_all_kernels` arg in the model's __init__
  3149. if allow_all_kernels:
  3150. init_contexts.append(allow_all_hub_kernels())
  3151. if is_deepspeed_zero3_enabled():
  3152. import deepspeed
  3153. # We cannot initialize the model on meta device with deepspeed when not quantized
  3154. if not is_quantized and not _is_ds_init_called:
  3155. logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
  3156. init_contexts.extend(
  3157. [
  3158. init.no_init_weights(),
  3159. deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
  3160. set_zero3_state(),
  3161. ]
  3162. )
  3163. elif is_quantized:
  3164. init_contexts.extend([torch.device("meta"), set_quantized_state()])
  3165. else:
  3166. # meta_device_safe_creation_ops patches torch.linspace to default to CPU
  3167. # so that custom models calling .item() during __init__ (e.g. drop-path
  3168. # schedules) don't crash on meta tensors.
  3169. init_contexts.extend([torch.device("meta"), init.meta_device_safe_creation_ops()])
  3170. return init_contexts
  3171. def _get_dtype_plan(self, dtype: torch.dtype) -> dict:
  3172. """Create the dtype_plan describing modules/parameters that should use the `keep_in_fp32` flag."""
  3173. dtype_plan = {}
  3174. # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
  3175. # in case of force loading a model that should stay in bf16 in fp16
  3176. # See https://github.com/huggingface/transformers/issues/20287 for details.
  3177. if self._keep_in_fp32_modules is not None and dtype == torch.float16:
  3178. dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32))
  3179. # The _keep_in_fp32_modules_strict was introduced to always force upcast to fp32, for both fp16 and bf16
  3180. if self._keep_in_fp32_modules_strict is not None and dtype in (torch.float16, torch.bfloat16):
  3181. dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules_strict, torch.float32))
  3182. return dtype_plan
  3183. def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None):
  3184. """
  3185. Set whether or not to use the `kernels` library to kernelize some layers of the model.
  3186. Args:
  3187. use_kernels (`bool`):
  3188. Whether or not to use the `kernels` library to kernelize some layers of the model.
  3189. kernel_config (`KernelConfig`, *optional*):
  3190. The kernel configuration to use to kernelize the model. If `None`, the default kernel mapping will be used.
  3191. """
  3192. if use_kernels:
  3193. if not is_kernels_available():
  3194. raise ValueError(
  3195. "`use_kernels=True` requires kernels>=0.9.0. Please install the latest version with `pip install -U kernels`"
  3196. )
  3197. from kernels import use_kernel_mapping
  3198. from .integrations.hub_kernels import register_kernel_mapping_transformers
  3199. register_kernel_mapping_transformers()
  3200. if kernel_config is not None and isinstance(kernel_config, KernelConfig):
  3201. # This will make sure the mapping is valid, and the layers are registered in the model
  3202. kernel_config.sanitize_kernel_mapping(self)
  3203. # This will create a compatible mapping for the model with the kernels library
  3204. kernel_config.create_compatible_mapping(self)
  3205. # This is a context manager to override the default kernel mapping
  3206. # We are calling kernelize inside this context manager using the use_kernels setter
  3207. # Param inherit_mapping should be False to avoid still loading kernel from remote
  3208. inherit_mapping = not kernel_config.use_local_kernel
  3209. with use_kernel_mapping(kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
  3210. self.use_kernels = True
  3211. # We use the default kernel mapping in .integrations.hub_kernels
  3212. else:
  3213. self.use_kernels = True
  3214. else:
  3215. self.use_kernels = False
  3216. @classmethod
  3217. def from_pretrained(
  3218. cls: type[SpecificPreTrainedModelType],
  3219. pretrained_model_name_or_path: str | os.PathLike | None,
  3220. *model_args,
  3221. config: PreTrainedConfig | str | os.PathLike | None = None,
  3222. cache_dir: str | os.PathLike | None = None,
  3223. ignore_mismatched_sizes: bool = False,
  3224. force_download: bool = False,
  3225. local_files_only: bool = False,
  3226. token: str | bool | None = None,
  3227. revision: str = "main",
  3228. use_safetensors: bool | None = None,
  3229. weights_only: bool = True,
  3230. **kwargs,
  3231. ) -> SpecificPreTrainedModelType:
  3232. r"""
  3233. Instantiate a pretrained pytorch model from a pre-trained model configuration.
  3234. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
  3235. the model, you should first set it back in training mode with `model.train()`.
  3236. The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
  3237. pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
  3238. task.
  3239. The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
  3240. weights are discarded.
  3241. Parameters:
  3242. pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
  3243. Can be either:
  3244. - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
  3245. - A path to a *directory* containing model weights saved using
  3246. [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
  3247. - `None` if you are both providing the configuration and state dictionary (resp. with keyword
  3248. arguments `config` and `state_dict`).
  3249. model_args (sequence of positional arguments, *optional*):
  3250. All remaining positional arguments will be passed to the underlying model's `__init__` method.
  3251. config (`Union[PreTrainedConfig, str, os.PathLike]`, *optional*):
  3252. Can be either:
  3253. - an instance of a class derived from [`PreTrainedConfig`],
  3254. - a string or path valid as input to [`~PreTrainedConfig.from_pretrained`].
  3255. Configuration for the model to use instead of an automatically loaded configuration. Configuration can
  3256. be automatically loaded when:
  3257. - The model is a model provided by the library (loaded with the *model id* string of a pretrained
  3258. model).
  3259. - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
  3260. save directory.
  3261. - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
  3262. configuration JSON file named *config.json* is found in the directory.
  3263. state_dict (`dict[str, torch.Tensor]`, *optional*):
  3264. A state dictionary to use instead of a state dictionary loaded from saved weights file.
  3265. This option can be used if you want to create a model from a pretrained configuration but load your own
  3266. weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
  3267. [`~PreTrainedModel.from_pretrained`] is not a simpler option.
  3268. cache_dir (`Union[str, os.PathLike]`, *optional*):
  3269. Path to a directory in which a downloaded pretrained model configuration should be cached if the
  3270. standard cache should not be used.
  3271. ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
  3272. Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
  3273. as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
  3274. checkpoint with 3 labels).
  3275. force_download (`bool`, *optional*, defaults to `False`):
  3276. Whether or not to force the (re-)download of the model weights and configuration files, overriding the
  3277. cached versions if they exist.
  3278. proxies (`dict[str, str]`, *optional*):
  3279. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  3280. 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
  3281. output_loading_info(`bool`, *optional*, defaults to `False`):
  3282. Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
  3283. local_files_only(`bool`, *optional*, defaults to `False`):
  3284. Whether or not to only look at local files (i.e., do not try to download the model).
  3285. token (`str` or `bool`, *optional*):
  3286. The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  3287. the token generated when running `hf auth login` (stored in `~/.huggingface`).
  3288. revision (`str`, *optional*, defaults to `"main"`):
  3289. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  3290. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  3291. identifier allowed by git.
  3292. <Tip>
  3293. To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
  3294. </Tip>
  3295. attn_implementation (`str`, *optional*):
  3296. The attention implementation to use in the model (if relevant). Can be any of
  3297. - `"eager"` (manual implementation of the attention)
  3298. - `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html))
  3299. - `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention))
  3300. - `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper))
  3301. - `"flash_attention_4"` (using [Dao-AILab/flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)).
  3302. By default, if available, SDPA will be used. The default is otherwise the manual `"eager"` implementation.
  3303. Accept HF kernel references in the form:
  3304. <namespace>/<repo_name>[@<revision>][:<kernel_name>]
  3305. - <namespace> and <repo_name> are any non-"/" and non-":" sequences.
  3306. - "@<revision>" is optional (branch, tag, or commit-ish), e.g. "@main", "@v1.2.0", "@abc123".
  3307. - ":<kernel_name>" is optional and selects a function inside the kernel repo.
  3308. - Both options can appear together and in this order only: @revision first, then :kernel_name.
  3309. - We intentionally allow a leading "<wrapper>|" prefix (e.g., "flash|...") because the code
  3310. strips it before loading; '|' is not excluded in the character classes here.
  3311. Examples that match:
  3312. "org/model"
  3313. "org/model@main"
  3314. "org/model:custom_kernel"
  3315. "org/model@v1.2.3:custom_kernel"
  3316. experts_implementation (`str`, *optional*):
  3317. The experts implementation to use in the model (if relevant). Can be any of:
  3318. - `"eager"` (sequential implementation of the experts matrix multiplications).
  3319. - `"batched_mm"` (using [`torch.bmm`](https://pytorch.org/docs/stable/generated/torch.bmm.html)).
  3320. - `"grouped_mm"` (using [`torch.nn.functional.grouped_mm`](https://docs.pytorch.org/docs/main/generated/torch.nn.functional.grouped_mm.html)).
  3321. By default, if the model supports it, `"grouped_mm"` will be used. The default is otherwise the manual `"eager"` implementation.
  3322. > Parameters for big model inference
  3323. dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"`):
  3324. Override the default `torch_dtype` and load the model under a specific `dtype`. The different options
  3325. are:
  3326. 1. `torch.float16` or `torch.bfloat16` or `torch.float`: load in a specified
  3327. `dtype`, ignoring the model's `config.dtype` if one exists. If not specified
  3328. - the model will get loaded in `torch.float` (fp32).
  3329. 2. `"auto"` - A `dtype` or `torch_dtype` entry in the `config.json` file of the model will be
  3330. attempted to be used. If this entry isn't found then next check the `dtype` of the first weight in
  3331. the checkpoint that's of a floating point type and use that as `dtype`. This will load the model
  3332. using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how
  3333. the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32.
  3334. 3. A string that is a valid `torch.dtype`. E.g. "float32" loads the model in `torch.float32`, "float16" loads in `torch.float16` etc.
  3335. <Tip>
  3336. For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or
  3337. reach out to the authors and ask them to add this information to the model's card and to insert the
  3338. `dtype` or `torch_dtype` entry in `config.json` on the hub.
  3339. </Tip>
  3340. device_map (`str` or `dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
  3341. A map that specifies where each submodule should go. It doesn't need to be refined to each
  3342. parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
  3343. same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
  3344. like `1`) on which the model will be allocated, the device map will map the entire model to this
  3345. device. Passing `device_map = 0` means put the whole model on GPU 0.
  3346. To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
  3347. more information about each option see [designing a device
  3348. map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
  3349. max_memory (`Dict`, *optional*):
  3350. A dictionary device identifier to maximum memory if using `device_map`. Will default to the maximum memory available for each
  3351. GPU and the available CPU RAM if unset.
  3352. tp_plan (`Optional[Union[dict, str]]`, *optional*):
  3353. A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Use `tp_plan="auto"` to
  3354. use the predefined plan based on the model. If it's a dict, then it should match between module names and desired layout.
  3355. Note that if you use it, you should launch your script accordingly with `torchrun [args] script.py`. This will be much
  3356. faster than using a `device_map`, but has limitations.
  3357. tp_size (`str`, *optional*):
  3358. A torch tensor parallel degree. If not provided would default to world size.
  3359. device_mesh (`torch.distributed.DeviceMesh`, *optional*):
  3360. A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now.
  3361. If provided, it has to contain dimension named `"tp"` in case it's > 1 dimensional, this dimension will be used for tensor parallelism
  3362. offload_folder (`str` or `os.PathLike`, *optional*):
  3363. If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
  3364. offload_buffers (`bool`, *optional*):
  3365. Whether or not to offload the buffers with the model parameters.
  3366. quantization_config (`Union[QuantizationConfigMixin,Dict]`, *optional*):
  3367. A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g
  3368. bitsandbytes, gptq).
  3369. subfolder (`str`, *optional*, defaults to `""`):
  3370. In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
  3371. specify the folder name here.
  3372. variant (`str`, *optional*):
  3373. If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin.
  3374. use_safetensors (`bool`, *optional*, defaults to `None`):
  3375. Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
  3376. is not installed, it will be set to `False`.
  3377. weights_only (`bool`, *optional*, defaults to `True`):
  3378. Indicates whether unpickler should be restricted to loading only tensors, primitive types,
  3379. dictionaries and any types added via torch.serialization.add_safe_globals().
  3380. When set to False, we can load wrapper tensor subclass weights.
  3381. key_mapping (`dict[str, str], *optional*):
  3382. A potential mapping of the weight names if using a model on the Hub which is compatible to a Transformers
  3383. architecture, but was not converted accordingly.
  3384. kwargs (remaining dictionary of keyword arguments, *optional*):
  3385. Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
  3386. `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
  3387. automatically loaded:
  3388. - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
  3389. underlying model's `__init__` method (we assume all relevant updates to the configuration have
  3390. already been done)
  3391. - If a configuration is not provided, `kwargs` will be first passed to the configuration class
  3392. initialization function ([`~PreTrainedConfig.from_pretrained`]). Each key of `kwargs` that
  3393. corresponds to a configuration attribute will be used to override said attribute with the
  3394. supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
  3395. will be passed to the underlying model's `__init__` function.
  3396. <Tip>
  3397. Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
  3398. use this method in a firewalled environment.
  3399. </Tip>
  3400. Examples:
  3401. ```python
  3402. >>> from transformers import BertConfig, BertModel
  3403. >>> # Download model and configuration from huggingface.co and cache.
  3404. >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased")
  3405. >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
  3406. >>> model = BertModel.from_pretrained("./test/saved_model/")
  3407. >>> # Update configuration during loading.
  3408. >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
  3409. >>> assert model.config.output_attentions == True
  3410. ```
  3411. """
  3412. state_dict = kwargs.pop("state_dict", None)
  3413. proxies = kwargs.pop("proxies", None)
  3414. tqdm_class = kwargs.pop("tqdm_class", None)
  3415. output_loading_info = kwargs.pop("output_loading_info", False)
  3416. from_pipeline = kwargs.pop("_from_pipeline", None)
  3417. from_auto_class = kwargs.pop("_from_auto", False)
  3418. dtype = kwargs.pop("dtype", None)
  3419. torch_dtype = kwargs.pop("torch_dtype", None) # kept for BC
  3420. device_map = kwargs.pop("device_map", None)
  3421. max_memory = kwargs.pop("max_memory", None)
  3422. offload_folder = kwargs.pop("offload_folder", None)
  3423. offload_buffers = kwargs.pop("offload_buffers", False)
  3424. quantization_config = kwargs.pop("quantization_config", None)
  3425. subfolder = kwargs.pop("subfolder", "")
  3426. commit_hash = kwargs.pop("_commit_hash", None)
  3427. variant = kwargs.pop("variant", None)
  3428. adapter_kwargs = (kwargs.pop("adapter_kwargs", {}) or {}).copy()
  3429. adapter_name = kwargs.pop("adapter_name", "default")
  3430. generation_config = kwargs.pop("generation_config", None)
  3431. gguf_file = kwargs.pop("gguf_file", None)
  3432. tp_plan = kwargs.pop("tp_plan", None)
  3433. tp_size = kwargs.pop("tp_size", None)
  3434. distributed_config: DistributedConfig = kwargs.pop("distributed_config", None)
  3435. device_mesh = kwargs.pop("device_mesh", None)
  3436. trust_remote_code = kwargs.pop("trust_remote_code", None)
  3437. allow_all_kernels = kwargs.pop("allow_all_kernels", False)
  3438. use_kernels = kwargs.pop("use_kernels", False)
  3439. kernel_config = kwargs.pop("kernel_config", None)
  3440. key_mapping = kwargs.pop("key_mapping", None)
  3441. if distributed_config is not None and tp_plan is None:
  3442. tp_plan = "auto"
  3443. # Not used anymore -- remove them from the kwargs
  3444. for name in ["mirror", "_fast_init", "low_cpu_mem_usage", "from_tf", "from_flax", "offload_state_dict"]:
  3445. _ = kwargs.pop(name, None)
  3446. # For BC on torch_dtype argument
  3447. if torch_dtype is not None:
  3448. dtype = dtype if dtype is not None else torch_dtype
  3449. if dtype is None:
  3450. dtype = "auto"
  3451. if is_offline_mode() and not local_files_only:
  3452. local_files_only = True
  3453. download_kwargs = {
  3454. "cache_dir": cache_dir,
  3455. "force_download": force_download,
  3456. "proxies": proxies,
  3457. "local_files_only": local_files_only,
  3458. "token": token,
  3459. "revision": revision,
  3460. "subfolder": subfolder,
  3461. }
  3462. download_kwargs_with_commit = {**download_kwargs, "commit_hash": commit_hash}
  3463. if state_dict is not None and (pretrained_model_name_or_path is not None or gguf_file is not None):
  3464. raise ValueError(
  3465. "`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies."
  3466. )
  3467. if device_map == "auto" and int(os.environ.get("WORLD_SIZE", "0")):
  3468. logger.info(
  3469. "You've set device_map=`auto` while triggering a distributed run with torchrun. This might lead to unexpected behavior. "
  3470. "If your plan is to load the model on each device, you should set device_map={"
  3471. ": PartialState().process_index} where PartialState comes from accelerate library"
  3472. )
  3473. if tp_plan is not None or tp_size is not None: # TP warnings, and setup
  3474. device_map, device_mesh, tp_size = initialize_tensor_parallelism(
  3475. tp_plan, tp_size=tp_size, device_mesh=device_mesh, device_map=device_map
  3476. )
  3477. if gguf_file is not None and not is_accelerate_available():
  3478. raise ValueError("accelerate is required when loading a GGUF file `pip install accelerate`.")
  3479. if adapter_kwargs is None:
  3480. adapter_kwargs = {}
  3481. _adapter_model_path, pretrained_model_name_or_path, adapter_kwargs = maybe_load_adapters(
  3482. pretrained_model_name_or_path,
  3483. download_kwargs_with_commit,
  3484. **adapter_kwargs,
  3485. )
  3486. device_map = check_and_set_device_map(device_map) # warn, error and fix the device map
  3487. user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
  3488. if from_pipeline is not None:
  3489. user_agent["using_pipeline"] = from_pipeline
  3490. # Load config if we don't provide a configuration
  3491. if not isinstance(config, PreTrainedConfig):
  3492. config_path = config if config is not None else pretrained_model_name_or_path
  3493. config, model_kwargs = cls.config_class.from_pretrained(
  3494. config_path,
  3495. return_unused_kwargs=True,
  3496. gguf_file=gguf_file,
  3497. _from_auto=from_auto_class,
  3498. _from_pipeline=from_pipeline,
  3499. **download_kwargs,
  3500. **kwargs,
  3501. )
  3502. if "gguf_file" in model_kwargs:
  3503. model_kwargs.pop("gguf_file")
  3504. commit_hash = model_kwargs.pop("_commit_hash", commit_hash)
  3505. else:
  3506. config = copy.deepcopy(config)
  3507. model_kwargs = kwargs
  3508. commit_hash = getattr(config, "_commit_hash", commit_hash)
  3509. download_kwargs_with_commit["commit_hash"] = commit_hash
  3510. # Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call
  3511. # to correctly redispatch recursively if the kwarg is provided
  3512. if "attn_implementation" in kwargs:
  3513. config._attn_implementation = kwargs.pop("attn_implementation")
  3514. if "experts_implementation" in kwargs:
  3515. config._experts_implementation = kwargs.pop("experts_implementation")
  3516. hf_quantizer, config, device_map = get_hf_quantizer(
  3517. config, quantization_config, device_map, weights_only, user_agent
  3518. )
  3519. if gguf_file:
  3520. if hf_quantizer is not None:
  3521. raise ValueError(
  3522. "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
  3523. )
  3524. if device_map is not None and (
  3525. (isinstance(device_map, dict) and "disk" in device_map.values()) or "disk" in device_map
  3526. ):
  3527. raise RuntimeError(
  3528. "One or more modules is configured to be mapped to disk. Disk offload is not supported for models "
  3529. "loaded from GGUF files."
  3530. )
  3531. if kernel_config is not None and not use_kernels:
  3532. logger.warning_once(
  3533. "A kernel_config was provided but use_kernels is False; setting use_kernels=True automatically. To suppress this warning, explicitly set use_kernels to True."
  3534. )
  3535. use_kernels = True
  3536. checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
  3537. pretrained_model_name_or_path=pretrained_model_name_or_path,
  3538. variant=variant,
  3539. gguf_file=gguf_file,
  3540. use_safetensors=use_safetensors,
  3541. download_kwargs=download_kwargs_with_commit,
  3542. user_agent=user_agent,
  3543. is_remote_code=cls.is_remote_code(),
  3544. transformers_explicit_filename=getattr(config, "transformers_weights", None),
  3545. tqdm_class=tqdm_class,
  3546. )
  3547. is_quantized = hf_quantizer is not None
  3548. if gguf_file:
  3549. from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
  3550. # we need a dummy model to get the state_dict - for this reason, we keep the state_dict as if it was
  3551. # passed directly as a kwarg from now on
  3552. with torch.device("meta"):
  3553. dummy_model = cls(config)
  3554. state_dict = load_gguf_checkpoint(checkpoint_files[0], return_tensors=True, model_to_load=dummy_model)[
  3555. "tensors"
  3556. ]
  3557. # Find the correct dtype based on current state
  3558. config, dtype = _get_dtype(
  3559. dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only, hf_quantizer
  3560. )
  3561. config.name_or_path = pretrained_model_name_or_path
  3562. model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called, allow_all_kernels)
  3563. config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
  3564. with ContextManagers(model_init_context):
  3565. model = cls(config, *model_args, **model_kwargs)
  3566. patch_output_recorders(model)
  3567. if hf_quantizer is not None: # replace module with quantized modules (does not touch weights)
  3568. hf_quantizer.preprocess_model(
  3569. model=model,
  3570. dtype=dtype,
  3571. device_map=device_map,
  3572. checkpoint_files=checkpoint_files,
  3573. use_kernels=use_kernels,
  3574. )
  3575. # Create the dtype_plan to potentially use the `keep_in_fp32` flags (this needs to be called on the already
  3576. # instantiated model, as the flags can be modified by instances sometimes)
  3577. dtype_plan = model._get_dtype_plan(dtype)
  3578. # Obtain the weight conversion mapping for this model if any are registered and apply to all submodels recursively
  3579. weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
  3580. if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights
  3581. model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)
  3582. # Prepare the full device map
  3583. if device_map is not None:
  3584. device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)
  3585. # Finalize model weight initialization
  3586. load_config = LoadStateDictConfig(
  3587. pretrained_model_name_or_path=pretrained_model_name_or_path,
  3588. ignore_mismatched_sizes=ignore_mismatched_sizes,
  3589. sharded_metadata=sharded_metadata,
  3590. device_map=device_map,
  3591. disk_offload_folder=offload_folder,
  3592. offload_buffers=offload_buffers,
  3593. dtype=dtype,
  3594. dtype_plan=dtype_plan,
  3595. hf_quantizer=hf_quantizer,
  3596. device_mesh=device_mesh,
  3597. weights_only=weights_only,
  3598. weight_mapping=weight_conversions,
  3599. use_safetensors=use_safetensors,
  3600. download_kwargs=download_kwargs,
  3601. )
  3602. loading_info, disk_offload_index = cls._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
  3603. loading_info = cls._finalize_model_loading(model, load_config, loading_info)
  3604. model.eval() # Set model in evaluation mode to deactivate Dropout modules by default
  3605. model.set_use_kernels(use_kernels, kernel_config)
  3606. # If it is a model with generation capabilities, attempt to load generation files (generation config,
  3607. # custom generate function)
  3608. if model.can_generate() and hasattr(model, "adjust_generation_fn") and not gguf_file:
  3609. model.adjust_generation_fn(
  3610. generation_config,
  3611. from_auto_class,
  3612. from_pipeline,
  3613. pretrained_model_name_or_path,
  3614. **download_kwargs,
  3615. trust_remote_code=trust_remote_code,
  3616. **kwargs,
  3617. )
  3618. # If the device_map has more than 1 device: dispatch model with hooks on all devices
  3619. if device_map is not None and len(set(device_map.values())) > 1:
  3620. accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, disk_offload_index, offload_buffers)
  3621. if hf_quantizer is not None:
  3622. model.hf_quantizer = hf_quantizer
  3623. hf_quantizer.postprocess_model(
  3624. model
  3625. ) # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing
  3626. if _adapter_model_path is not None:
  3627. if token is not None:
  3628. adapter_kwargs["token"] = token
  3629. loading_info = model.load_adapter(
  3630. _adapter_model_path,
  3631. adapter_name=adapter_name,
  3632. load_config=load_config,
  3633. adapter_kwargs=adapter_kwargs,
  3634. )
  3635. if output_loading_info:
  3636. return model, loading_info.to_dict()
  3637. return model
  3638. @staticmethod
  3639. def _load_pretrained_model(
  3640. model: "PreTrainedModel",
  3641. state_dict: dict | None,
  3642. checkpoint_files: list[str] | None,
  3643. load_config: LoadStateDictConfig,
  3644. expected_keys: list[str] | None = None,
  3645. ) -> tuple[LoadStateDictInfo, dict]:
  3646. """Perform the actual loading of some checkpoints into a `model`, by reading them from disk and dispatching them accordingly."""
  3647. is_quantized = load_config.is_quantized
  3648. is_hqq_or_quark = is_quantized and load_config.hf_quantizer.quantization_config.quant_method in {
  3649. QuantizationMethod.HQQ,
  3650. QuantizationMethod.QUARK,
  3651. }
  3652. # Model's definition arriving here is final (TP hooks added, quantized layers replaces)
  3653. expected_keys = list(model.state_dict().keys()) if expected_keys is None else expected_keys
  3654. if logger.level >= logging.WARNING:
  3655. verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))
  3656. # This offload index if for params explicitly on the "disk" in the device_map
  3657. disk_offload_index = None
  3658. # Prepare parameters offloading if needed
  3659. if load_config.device_map is not None and "disk" in load_config.device_map.values():
  3660. disk_offload_index = accelerate_disk_offload(
  3661. model,
  3662. load_config.disk_offload_folder,
  3663. checkpoint_files,
  3664. load_config.device_map,
  3665. load_config.sharded_metadata,
  3666. load_config.dtype,
  3667. load_config.weight_mapping,
  3668. )
  3669. # Warmup cuda to load the weights much faster on devices
  3670. if load_config.device_map is not None and not is_hqq_or_quark:
  3671. expanded_device_map = expand_device_map(load_config.device_map, expected_keys)
  3672. caching_allocator_warmup(model, expanded_device_map, load_config.hf_quantizer)
  3673. error_msgs = []
  3674. if is_deepspeed_zero3_enabled() and not is_quantized:
  3675. if state_dict is None:
  3676. merged_state_dict = {}
  3677. for ckpt_file in checkpoint_files:
  3678. merged_state_dict.update(
  3679. load_state_dict(ckpt_file, map_location="cpu", weights_only=load_config.weights_only)
  3680. )
  3681. state_dict = merged_state_dict
  3682. error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict, load_config)
  3683. # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
  3684. loading_info = LoadStateDictInfo(
  3685. missing_keys=missing_keys,
  3686. error_msgs=error_msgs,
  3687. unexpected_keys=set(),
  3688. mismatched_keys=set(),
  3689. conversion_errors={},
  3690. )
  3691. else:
  3692. all_pointer = set()
  3693. if state_dict is not None:
  3694. merged_state_dict = state_dict
  3695. elif checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") and state_dict is None:
  3696. merged_state_dict = {}
  3697. for file in checkpoint_files:
  3698. file_pointer = safe_open(file, framework="pt", device="cpu")
  3699. all_pointer.add(file_pointer)
  3700. for k in file_pointer.keys():
  3701. merged_state_dict[k] = file_pointer.get_slice(k) # don't materialize yet
  3702. # Checkpoints are .bin
  3703. elif checkpoint_files is not None:
  3704. merged_state_dict = {}
  3705. for ckpt_file in checkpoint_files:
  3706. merged_state_dict.update(load_state_dict(ckpt_file))
  3707. else:
  3708. raise ValueError("Neither a state dict nor checkpoint files were found.")
  3709. loading_info, disk_offload_index = convert_and_load_state_dict_in_model(
  3710. model=model,
  3711. state_dict=merged_state_dict,
  3712. load_config=load_config,
  3713. tp_plan=model._tp_plan,
  3714. disk_offload_index=disk_offload_index,
  3715. )
  3716. # finally close all opened file pointers
  3717. for k in all_pointer:
  3718. k.__exit__(None, None, None)
  3719. return loading_info, disk_offload_index
  @staticmethod
  def _finalize_model_loading(
      model, load_config: LoadStateDictConfig, loading_info: LoadStateDictInfo
  ) -> LoadStateDictInfo:
      """Perform all post processing operations after having loaded some checkpoints into a model, such as moving
      missing keys from meta device to their expected device, reinitializing missing weights according to proper
      distributions, tying the weights and logging the loading report.

      The loading report is logged from a `finally` clause, so it is emitted even when one of the steps raises.

      Args:
          model: The model the checkpoints were loaded into.
          load_config (`LoadStateDictConfig`): Loading options. This method reads `device_map`, `device_mesh`,
              `hf_quantizer`, `is_quantized`, `pretrained_model_name_or_path` and `ignore_mismatched_sizes`.
          loading_info (`LoadStateDictInfo`): Loading bookkeeping (e.g. `missing_keys`), handed to each step.

      Returns:
          `LoadStateDictInfo`: The `loading_info` object that was passed in.
      """
      try:
          # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
          model.mark_tied_weights_as_initialized(loading_info)

          # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
          # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
          model._move_missing_keys_from_meta_to_device(
              loading_info.missing_and_mismatched(),
              load_config.device_map,
              load_config.device_mesh,
              load_config.hf_quantizer,
          )

          # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
          model._initialize_missing_keys(load_config.is_quantized)

          # Tie the weights
          model.tie_weights(missing_keys=loading_info.missing_keys, recompute_mapping=False)

          # Adjust missing and unexpected keys
          model._adjust_missing_and_unexpected_keys(loading_info)
      finally:
          # Always log the report, whether or not the steps above succeeded
          log_state_dict_report(
              model=model,
              pretrained_model_name_or_path=load_config.pretrained_model_name_or_path,
              ignore_mismatched_sizes=load_config.ignore_mismatched_sizes,
              loading_info=loading_info,
              logger=logger,
          )

      return loading_info
  def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
      """Return the submodules that own at least one of the given weight names.

      Args:
          names: Dotted weight names, e.g. `encoder.layer.weight`.
          add_prefix (`bool`, *optional*, defaults to `False`):
              Prepend `base_model_prefix` to every module name before matching. Ignored if `remove_prefix` is set.
          remove_prefix (`bool`, *optional*, defaults to `False`):
              Strip a leading `"{base_model_prefix}."` from every module name before matching.

      Returns:
          `list`: The matching modules, in `named_modules()` order.
      """

      def _parent(dotted_name):
          # Everything before the last dot ("" when there is no dot)
          return dotted_name.rpartition(".")[0]

      def _match_name(module_name):
          # Module name as it should be compared against the weight-name parents
          if remove_prefix:
              return module_name.removeprefix(f"{self.base_model_prefix}.")
          if add_prefix:
              if module_name:
                  return f"{self.base_model_prefix}.{module_name}"
              return self.base_model_prefix
          return module_name

      wanted = {_parent(key) for key in names}
      # torch.nn.ParameterList is a special case where two parameter keywords
      # are appended to the module name, *e.g.* bert.special_embeddings.0
      wanted |= {_parent(_parent(key)) for key in names if key[-1:].isdigit()}

      # keep every module that owns at least one of the requested weight names
      return [module for module_name, module in self.named_modules() if _match_name(module_name) in wanted]
  @classmethod
  def register_for_auto_class(cls, auto_class="AutoModel"):
      """
      Register this class with a given auto class. This should only be used for custom models as the ones in the
      library are already mapped with an auto class.

      Args:
          auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`):
              The auto class to register this new model with. A class is reduced to its `__name__`.

      Raises:
          ValueError: If `transformers.models.auto` exposes no attribute with that name.
      """
      # Work with the class name only, whatever form the caller used
      auto_class = auto_class if isinstance(auto_class, str) else auto_class.__name__

      import transformers.models.auto as auto_module

      if hasattr(auto_module, auto_class):
          cls._auto_class = auto_class
          return

      raise ValueError(f"{auto_class} is not a valid auto class.")
  def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
      """
      Emit a one-time warning when `input_ids` seem to contain padding while no attention mask was passed.

      Nothing happens while tracing, when `attention_mask` is provided, or when the config has no `pad_token_id`.
      """
      # Skip the check during tracing.
      if is_tracing(input_ids):
          return

      if attention_mask is not None:
          return

      pad_token_id = self.config.pad_token_id
      if pad_token_id is None:
          return

      # Check only the first and last input IDs to reduce overhead.
      if pad_token_id not in input_ids[:, [-1, 0]]:
          return

      warn_string = (
          "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See "
          "https://huggingface.co/docs/transformers/troubleshooting"
          "#incorrect-output-when-padding-tokens-arent-masked."
      )

      # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
      # attention_mask or not. In this case, we should still show a warning because this is a rare case.
      # NOTE: `sep_token_id` is not used in all models and it can be absent in the config
      sep_token_id = getattr(self.config, "sep_token_id", None)
      bos_token_id = self.config.bos_token_id
      eos_token_id = self.config.eos_token_id
      pad_is_special = any(
          token_id is not None and token_id == pad_token_id for token_id in (bos_token_id, eos_token_id, sep_token_id)
      )
      if pad_is_special:
          warn_string += (
              f"\nYou may ignore this warning if your `pad_token_id` ({pad_token_id}) is identical "
              f"to the `bos_token_id` ({bos_token_id}), `eos_token_id` ({eos_token_id}), "
              f"or the `sep_token_id` ({sep_token_id}), and your input is not padded."
          )

      logger.warning_once(warn_string)
  @property
  def supports_tp_plan(self):
      """
      Whether a tensor parallelism plan is available.

      The plan is looked up, in order, on the model itself, on its base model, and on the config
      (`base_model_tp_plan`); the first non-empty one is enough.
      """
      return bool(self._tp_plan or self.base_model._tp_plan or self.config.base_model_tp_plan)
  @property
  def tp_size(self):
      """
      Returns the model's tensor parallelism degree.

      The value is read from `self._tp_size` as-is.
      """
      # if None, the model didn't undergo tensor parallel sharding
      return self._tp_size
  @property
  def supports_pp_plan(self):
      """
      Whether a pipeline parallelism plan is available.

      The plan is looked up, in order, on the model itself, on its base model, and on the config
      (`base_model_pp_plan`); the first non-empty one is enough.
      """
      return bool(self._pp_plan or self.base_model._pp_plan or self.config.base_model_pp_plan)
  @property
  def loss_function(self):
      """
      The loss callable used by the model.

      A loss assigned through the setter takes precedence. Otherwise the `loss_type` attribute is looked up in
      `LOSS_MAPPING`; when it is `None` or unknown, a one-time warning is logged and the `"ForCausalLM"` entry
      is returned.
      """
      # An explicitly assigned loss always wins
      try:
          return self._loss_function
      except AttributeError:
          pass

      loss_type = getattr(self, "loss_type", None)
      if loss_type is not None and loss_type in LOSS_MAPPING:
          return LOSS_MAPPING[loss_type]

      logger.warning_once(
          f"`loss_type={loss_type}` was set in the config but it is unrecognized. "
          f"Using the default loss: `ForCausalLMLoss`."
      )
      return LOSS_MAPPING["ForCausalLM"]

  @loss_function.setter
  def loss_function(self, value):
      """Override the loss callable returned by the `loss_function` property."""
      self._loss_function = value
  def kernelize(self, mode=None):
      """
      Kernelize the model with the `kernels` library and flag it as using kernels.

      Args:
          mode (`kernels.Mode`, *optional*):
              Mode handed to `kernels.kernelize`. When `None`, it is derived from the module state:
              `Mode.TRAINING` if `self.training` is set, `Mode.INFERENCE` otherwise.

      Raises:
          ValueError: If the `kernels` package is not available.
      """
      if not is_kernels_available():
          raise ValueError(
              "Kernels are not available. To use kernels, please install kernels using `pip install kernels`"
          )
      from kernels import Device, Mode, kernelize

      # An explicit `mode` always wins. Writing both conditionals as a single chained expression would group as
      # `INFERENCE if not training else (TRAINING if mode is None else mode)` and drop `mode` in eval mode.
      if mode is None:
          mode = Mode.TRAINING if self.training else Mode.INFERENCE
      kernelize(self, device=Device(type=self.device.type), mode=mode)
      self._use_kernels = True
  @property
  def use_kernels(self) -> bool:
      """Whether the model is flagged as using kernels (`False` when the `_use_kernels` flag was never set)."""
      return getattr(self, "_use_kernels", False)

  @use_kernels.setter
  def use_kernels(self, value: bool) -> None:
      """
      Turn kernels on or off.

      A truthy `value` calls `kernelize()` unless kernels are already flagged as enabled. A falsy `value` only
      resets the flag; a one-time warning is logged if kernels were enabled.
      """
      already_enabled = getattr(self, "_use_kernels", False)

      if value:
          # Avoid re-kernelizing if already enabled
          if not already_enabled:
              self.kernelize()
          return

      if already_enabled:
          logger.warning_once(
              "Disabling kernels at runtime is a no-op as there is no 'unkernelize' routine; keeping current kernels active."
          )
      self._use_kernels = False
  def get_compiled_call(self, compile_config: CompileConfig | None) -> Callable:
      """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
      non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
      want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
      (where we want the speed-ups of compiled version with static shapes).

      Args:
          compile_config (`CompileConfig`, *optional*):
              Options forwarded to `torch.compile` through `to_dict()`. A default `CompileConfig()` is used when
              nothing is passed.

      Returns:
          `Callable`: The cached compiled call, or the plain `self.__call__` for `llama4` model types.
      """
      if "llama4" in self.config.model_type:  # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
          return self.__call__

      compile_config = compile_config or CompileConfig()
      fallback_config = getattr(self.generation_config, "compile_config", None) or CompileConfig()
      previous_config = getattr(self, "_last_compile_config", fallback_config)

      # Only reset it if not present or different from previous config
      needs_compile = not hasattr(self, "_compiled_call") or previous_config != compile_config
      if needs_compile:
          self._last_compile_config = compile_config
          self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict())
      return self._compiled_call
  @classmethod
  def is_backend_compatible(cls):
      """Return the class-level `_supports_attention_backend` flag."""
      return cls._supports_attention_backend
  def _move_missing_keys_from_meta_to_device(
      self,
      missing_keys: list[str],
      device_map: dict | None,
      device_mesh: "torch.distributed.device_mesh.DeviceMesh | None",
      hf_quantizer: HfQuantizer | None,
  ) -> None:
      """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts)
      back from meta device to their device according to the `device_map` if any, else cpu. Takes care of sharding those
      missing parameters if `device_mesh` is provided, i.e. we are using TP.
      All non-persistent buffers are also moved back to the correct device (they are not part of the state_dict, but are
      not missing either).

      Args:
          missing_keys: Names of the missing parameters/buffers. NOTE(review): annotated as a list, but used as the
              left operand of a set difference with `all_tied_weights_keys.keys()` - callers presumably pass a set.
          device_map (`dict`, *optional*): Device map handed to `get_device` to find the target device of each key.
          device_mesh (`DeviceMesh`, *optional*): When set, missing params go through `shard_and_distribute_module`.
          hf_quantizer (`HfQuantizer`, *optional*): Only used to know whether the model is quantized.
      """
      is_quantized = hf_quantizer is not None
      # This is the only case where we do not initialize the model on meta device, so we don't have to do anything here
      if is_deepspeed_zero3_enabled() and not is_quantized:
          return

      # In this case we need to move everything back
      if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
          # Every parameter and buffer is replaced by an uninitialized cpu tensor of the same shape/dtype
          for key, param in self.named_parameters():
              value = torch.empty_like(param, device="cpu")
              _load_parameter_into_model(self, key, value)
          for key, buffer in self.named_buffers():
              value = torch.empty_like(buffer, device="cpu")
              _load_parameter_into_model(self, key, value)
          return

      # The tied weight keys are in the "missing" usually, but they should not be moved (they will be tied anyway)
      # This is especially important because if they are moved, they will lose the `_is_hf_initialized` flag, and they
      # will be re-initialized for nothing (which can be quite long)
      for key in missing_keys - self.all_tied_weights_keys.keys():
          param = self.get_parameter_or_buffer(key)
          param_device = get_device(device_map, key, valid_torch_device=True)
          # Uninitialized placeholder on the target device; actual values are set later by weight initialization
          value = torch.empty_like(param, device=param_device)
          # For TP, we may need to shard the param
          if device_mesh is not None:
              shard_and_distribute_module(
                  self, value, param, key, None, False, device_mesh.get_local_rank(), device_mesh
              )
          # Otherwise, just move it to device
          else:
              _load_parameter_into_model(self, key, value)
      # We need to move back non-persistent buffers as well, as they are not part of loaded weights anyway
      for key, buffer in self.named_non_persistent_buffers():
          buffer_device = get_device(device_map, key, valid_torch_device=True)
          value = torch.empty_like(buffer, device=buffer_device)
          _load_parameter_into_model(self, key, value)
  def _initialize_missing_keys(self, is_quantized: bool) -> None:
      """
      Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to
      `_initialize_weights`. Indeed, since the corresponding weights are missing from the state dict, they will not be replaced and need to
      be initialized correctly (i.e. weight initialization distribution).
      Also marks non-missing params/buffers with `_is_hf_initialized` and propagates this flag to modules,
      so that `_initialize_weights` can skip fully-initialized modules entirely.

      Args:
          is_quantized (`bool`):
              Whether the model is quantized. The DeepSpeed ZeRO-3 gathering path is only taken when this is `False`.
      """
      if is_fsdp_enabled() and not is_local_dist_rank_0():
          # Handle FSDP edge case when using cpu ram efficient loading to ensure it is marked as initialized
          # since it will get its weights broadcasted from rank0
          # We actually need to do that only because we want to re-initialize non-persistent buffers with correct values.
          # Everything else in the state_dict will be gathered from rank0, so we don't need re-initialization.
          # We could simply early return after buffer inits if we had a way to init only the non-persistent buffers
          for key in self.state_dict():
              try:
                  param_or_buffer = self.get_parameter_or_buffer(key)
                  param_or_buffer._is_hf_initialized = True
              except AttributeError:
                  pass  # may happen when handling pre-quantized weights
          # Flag the model itself as well
          self._is_hf_initialized = True

      # This will only initialize submodules that are not marked as initialized by the line above.
      if is_deepspeed_zero3_enabled() and not is_quantized:
          import deepspeed

          # keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them
          not_initialized_parameters = list(
              {v for v in self.state_dict(keep_vars=True).values() if not getattr(v, "_is_hf_initialized", False)}
          )
          # Only the not-yet-initialized tensors are gathered while the weights are initialized
          with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
              self.initialize_weights()
      else:
          self.initialize_weights()
  def _adjust_missing_and_unexpected_keys(self, loading_info: LoadStateDictInfo) -> None:
      """Filter `loading_info.missing_keys` and `loading_info.unexpected_keys` with the model's ignore patterns
      (`_keys_to_ignore_on_load_missing` / `_keys_to_ignore_on_load_unexpected`), to avoid raising unneeded
      warnings/errors. The two attributes are reassigned on `loading_info`; nothing is returned.
      """

      def _compile_alternation(patterns):
          # One capture group per pattern, OR-ed together; `None` when there is nothing to ignore
          if len(patterns) == 0:
              return None
          return re.compile("|".join(rf"({pattern})" for pattern in patterns))

      # Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model
      # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
      # `_keys_to_ignore_on_load_unexpected` as it touches many models -> we add it manually to the existing patterns
      owns_inv_freq = any(name.endswith("rotary_emb.inv_freq") for name, _ in self.named_buffers())
      inv_freq_patterns = [r"rotary_emb\.inv_freq"] if owns_inv_freq else []

      missing_regex = _compile_alternation(self._keys_to_ignore_on_load_missing or [])
      unexpected_regex = _compile_alternation((self._keys_to_ignore_on_load_unexpected or []) + inv_freq_patterns)

      # Clean-up missing keys
      if missing_regex is not None:
          loading_info.missing_keys = {key for key in loading_info.missing_keys if not missing_regex.search(key)}

      # Clean-up unexpected keys
      if unexpected_regex is not None:
          loading_info.unexpected_keys = {
              key for key in loading_info.unexpected_keys if not unexpected_regex.search(key)
          }
  def mark_tied_weights_as_initialized(self, loading_info):
      """Adds the `_is_hf_initialized` flag on parameters that will be tied, in order to avoid initializing them
      later as they will be tied (overwritten) anyway.
      This is very important as most embeddings are tied, and they are huge params (vocabularies are often 256k), so
      running inits on them is very costly.

      For remote code models, `loading_info.missing_keys` is additionally pruned (reassigned on `loading_info`) of
      the keys that are already flagged as initialized without being known tied weights.

      Args:
          loading_info: Loading bookkeeping object exposing a `missing_keys` collection.
      """
      # The mapping may be absent on the model; treat that as "no tied weights" everywhere in this method, so that
      # the remote-code clean-up below does not fail on the very same missing attribute.
      tied_keys = getattr(self, "all_tied_weights_keys", {})
      for tied_param in tied_keys.keys():
          param = self.get_parameter(tied_param)
          param._is_hf_initialized = True

      # Some remote code models define module tying (not parameter tying) in their __init__. When modules themselves are shared,
      # weights inside both modules appear in the `state_dict` but only one will appear in the safetensors checkpoints
      # as they are inherently tied because the 2 modules are the same object. In this case, once we load a parameter
      # inside one of the 2 modules, the other will also automatically be loaded and will have the `_is_hf_initialized`
      # flag (because we call `setattr` with the loaded param on the module, which is the same object), but its counterpart
      # will still appear as a missing key as we never get it out of the set (because it appears in the state_dict as well).
      # So we remove it now - otherwise it's considered missing and will be wrongly reinitialized
      # Note: this is never an issue in main Transformers, as we never do module-tying, only parameter-tying, and we know
      # which params are supposed to be tied to which other params
      if self.is_remote_code():
          # Remove those that are already initialized, but appear as missing due to module tying (only if they are not known
          # tied weights, i.e. we did not explicitly mark them as initialized just above)
          loading_info.missing_keys = {
              key
              for key in loading_info.missing_keys
              if key in tied_keys or not getattr(self.get_parameter_or_buffer(key), "_is_hf_initialized", False)
          }
  def get_parameter_or_buffer(self, target: str):
      """
      Look `target` up as a parameter, then as a buffer, then as a module `_extra_state`, and return the first hit.
      This combines `get_parameter()` and `get_buffer()` in a single handy function. Note that it only works if
      `target` is a leaf of the model.

      Raises:
          AttributeError: If `target` is none of the three.
      """
      for accessor_name in ("get_parameter", "get_buffer"):
          try:
              return getattr(self, accessor_name)(target)
          except AttributeError:
              continue

      module, param_name = get_module_from_name(self, target)
      if param_name == "_extra_state":
          # Only modules overriding `get_extra_state` actually provide an extra state
          base_getter = torch.nn.Module.get_extra_state
          if getattr(module.__class__, "get_extra_state", base_getter) is not base_getter:
              return module.get_extra_state()

      raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.")
  def named_non_persistent_buffers(
      self, recurse: bool = True, remove_duplicate: bool = True
  ) -> Iterator[tuple[str, torch.Tensor]]:
      """Same as `named_buffers`, restricted to the non-persistent buffers.

      Args:
          recurse (`bool`, *optional*, defaults to `True`): Forwarded to `named_buffers`.
          remove_duplicate (`bool`, *optional*, defaults to `True`): Forwarded to `named_buffers`.

      Yields:
          `tuple[str, torch.Tensor]`: The qualified name and tensor of each non-persistent buffer.
      """
      buffers = self.named_buffers(recurse=recurse, remove_duplicate=remove_duplicate)
      for qualified_name, tensor in buffers:
          # Persistence is recorded in `_non_persistent_buffers_set` of the module that directly owns the buffer,
          # so resolve that owner first (the empty path resolves to the model itself)
          owner_path, _, local_name = qualified_name.rpartition(".")
          owner = self.get_submodule(owner_path)
          if local_name not in owner._non_persistent_buffers_set:
              continue
          yield qualified_name, tensor
  def train(self, mode: bool = True):
      """Set the training mode through the parent `train`, then call `kernelize()` again if kernels are enabled.

      Args:
          mode (`bool`, *optional*, defaults to `True`): Forwarded to the parent `train`.

      Returns:
          Whatever the parent `train` returned.
      """
      out = super().train(mode)
      # `kernelize()` picks its mode from `self.training` when called without argument, hence the re-run here
      if self.use_kernels:
          self.kernelize()
      return out
  def eval(self):
      """Switch to evaluation mode; shorthand for `self.train(False)`."""
      return self.train(False)
  @classmethod
  def is_remote_code(cls) -> bool:
      """Return `True` when the class has a non-`None` `_auto_class` (the attribute `register_for_auto_class` sets)."""
      return cls._auto_class is not None
  # Rebind `push_to_hub` to the result of `copy_func` (presumably a separate copy of the function - confirm against
  # `copy_func`), then fill the placeholders of its docstring with model-specific wording
  PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
  # The docstring can be `None` (hence the guard), in which case there is nothing to format
  if PreTrainedModel.push_to_hub.__doc__ is not None:
      PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format(
          object="model", object_class="AutoModel", object_files="model file"
      )
  def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
      """
      Unwrap a model from the containers that may surround it (as used in distributed training).

      Args:
          model (`torch.nn.Module`): The model to unwrap.
          recursive (`bool`, *optional*, defaults to `False`):
              Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
              recursively, not just the top-level distributed containers. Only forwarded to accelerate when `True`.

      Returns:
          `torch.nn.Module`: The unwrapped model.
      """
      # Use accelerate implementation if available (should always be the case when using torch)
      # This is for pytorch, as we also have to handle things like dynamo
      if is_accelerate_available():
          extra_kwargs = {"recursive": recursive} if recursive else {}
          return extract_model_from_parallel(model, **extra_kwargs)

      # since there could be multiple levels of wrapping, peel `.module` off until none is left
      while hasattr(model, "module"):
          model = model.module
      return model
  4112. def is_accelerator_device(device: str | int | torch.device) -> bool:
  4113. """Check if the device is an accelerator. We need to function, as device_map can be "disk" as well, which is not
  4114. a proper `torch.device`.
  4115. """
  4116. if device == "disk":
  4117. return False
  4118. else:
  4119. return torch.device(device).type not in ["meta", "cpu"]
  def get_total_byte_count(
      model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: HfQuantizer | None = None
  ):
      """
      Compute, for each device, the total number of bytes needed to load the model parameters mapped to it.
      This is useful for caching_allocator_warmup as we want to know how much cache we need to pre-allocate.

      Args:
          model (`PreTrainedModel`): The model whose parameters/buffers are measured.
          accelerator_device_map (`dict`): Mapping from parameter name to device.
          hf_quantizer (`HfQuantizer`, *optional*): When given, provides the element size of each parameter.

      Returns:
          `defaultdict`: Mapping from device to byte count.
      """
      bytes_per_device = defaultdict(int)
      tied_param_names = model.all_tied_weights_keys.keys()
      distributed_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
      tp_plan = model._tp_plan if distributed_ready else []

      for param_name, device in accelerator_device_map.items():
          # Skip if the parameter has already been accounted for (tied weights)
          if param_name in tied_param_names:
              continue

          param = model.get_parameter_or_buffer(param_name)
          element_size = (
              param.element_size()
              if hf_quantizer is None
              else hf_quantizer.param_element_size(model, param_name, param)
          )
          byte_count = param.numel() * element_size

          if len(tp_plan) > 0:
              # Parameters covered by the TP plan are split across the whole world size
              sharded = _get_parameter_tp_plan(param_name, tp_plan, is_weight=True) is not None
              divisor = torch.distributed.get_world_size() if sharded else 1
              byte_count //= divisor

          bytes_per_device[device] += byte_count
      return bytes_per_device
  def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: HfQuantizer | None):
      """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
      device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
      the model, which is actually the loading speed bottleneck.
      Calling this function allows to cut the model loading time by a very large margin.

      A few facts related to loading speed (taking into account the use of this function):
      - When loading a model the first time, it is usually slower than the subsequent times, because the OS is very likely
        to cache the different state dicts (if enough resources/RAM are available)
      - Trying to force the OS to cache the files in advance (by e.g. accessing a small portion of them) is really hard,
        and not a good idea in general as this is low level OS optimizations that depend on resource usage anyway
      - As of 18/03/2025, loading a Llama 70B model with TP takes ~1 min without file cache, and ~13s with full file cache.
        The baseline, i.e. only loading the tensor shards on device and adjusting dtype (i.e. copying them) is ~5s with full cache.
        These numbers are reported for TP on 4 H100 GPUs.
      - It is useless to pre-allocate more than the model size in this function (i.e. using an `allocation_factor` > 1) as
        cudaMalloc is not a bottleneck at all anymore
      - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
        However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.

      Args:
          model (`PreTrainedModel`): The model about to be loaded.
          expanded_device_map (`dict`): Mapping from parameter name to device; "disk", cpu and meta entries are ignored.
          hf_quantizer (`HfQuantizer`, *optional*): Forwarded to `get_total_byte_count`.
      """
      # Remove disk, cpu and meta devices, and cast to proper torch.device
      accelerator_device_map = {
          param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
      }
      # Nothing lives on an accelerator: no warmup needed
      if not accelerator_device_map:
          return

      total_byte_count = get_total_byte_count(model, accelerator_device_map, hf_quantizer)

      # This will kick off the caching allocator to avoid having to Malloc afterwards
      for device, byte_count in total_byte_count.items():
          if device.type in ["cuda", "xpu"]:
              accelerator_module = getattr(torch, device.type)
              index = device.index if device.index is not None else accelerator_module.current_device()
              free_device_memory, total_device_memory = accelerator_module.mem_get_info(index)
              # Memory already reserved by the allocator but not currently backing any tensor
              unused_memory = accelerator_module.memory_reserved(index) - accelerator_module.memory_allocated(index)
              # If we have reserved but unused memory, we can lower the allocation we want to make, but only if it's still
              # higher than the unused memory. This is because otherwise torch will use that unused memory when performing
              # our own allocation, thus not allocating any new memory from the GPU. For example if byte_count=6 GiB,
              # unused_memory=4 GiB, then we cannot allocate only 2 GiB as this would *likely* (may not be exact, due to
              # fragmentation issues) simply use the pool of 4 GiB unused memory that is available. In those cases, it's better
              # to allocate more than the technically only 2 GiB required
              if byte_count - unused_memory > unused_memory:
                  byte_count = byte_count - unused_memory
              # Minimum amount that will trigger new gpu allocation, even if it's technically "too much" compared to what we need
              elif byte_count - unused_memory > 1.5 * 1024**3:
                  # Nothing we can do here, the memory will need to fill itself as we load params, but we cannot reallocate
                  # from gpu until the unused memory is not filled
                  if unused_memory + 1 > free_device_memory:
                      byte_count = 0
                  # We allocate the minimum amount that will force new gpu allocation, even if it's technically "too much"
                  else:
                      byte_count = unused_memory + 1
              # If we only need to reallocate less than 1.5 GiB of what is already allocated, then don't allocate more
              else:
                  byte_count = 0

              # Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
              # than that amount might sometimes lead to unnecessary cuda/xpu OOM, if the last parameter to be loaded on the device is large,
              # and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
              # the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
              # to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
              # Note that we use an absolute value instead of device proportion here, as a 8GiB device could still allocate too much
              # if using e.g. 90% of device size, while a 140GiB device would allocate too little
              # NOTE(review): this bound is negative on a device with less than 1.2 GiB of total memory - confirm that
              # such devices never reach this code, as a negative size cannot be allocated below
              byte_count = min(byte_count, total_device_memory - 1.2 * 1024**3)

          # We divide by 2 here as we allocate in fp16
          _ = torch.empty(int(byte_count // 2), dtype=torch.float16, device=device, requires_grad=False)
  class AttentionInterface(GeneralInterface):
      """
      Dict-like registry of the attention functions that are allowed to be used. A new attention function is added
      with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
      it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
      """

      # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
      # a new instance is created (in order to locally override a given function)
      _global_mapping = {
          "flash_attention_4": flash_attention_forward,
          "flash_attention_3": flash_attention_forward,
          "flash_attention_2": flash_attention_forward,
          "flex_attention": flex_attention_forward,
          "sdpa": sdpa_attention_forward,
          "paged|flash_attention_4": paged_attention_forward,
          "paged|flash_attention_3": paged_attention_forward,
          "paged|flash_attention_2": paged_attention_forward,
          "paged|sdpa": sdpa_attention_paged_forward,
          "paged|eager": eager_paged_attention_forward,
      }

      def get_interface(self, attn_implementation: str, default: Callable) -> Callable:
          """Look up the function registered for `attn_implementation`, falling back on `default`.

          The name is strictly validated: anything that is neither `"eager"` nor registered raises a `KeyError`.
          A `None` name is tolerated, but a one-time warning is logged.
          """
          if attn_implementation is None:
              logger.warning_once(
                  "You tried to access the `AttentionInterface` with a `config._attn_implementation` set to `None`. This "
                  "is expected if you use an Attention Module as a standalone Module. If this is not the case, something went "
                  "wrong with the dispatch of `config._attn_implementation`"
              )
          else:
              is_known = attn_implementation == "eager" or attn_implementation in self
              if not is_known:
                  raise KeyError(
                      f"`{attn_implementation}` is not a valid attention implementation registered in the `AttentionInterface`"
                  )
          return super().get(attn_implementation, default)
  # Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones.
  # Module-level instance, created once when this file is imported.
  ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
  class PreTrainedAudioTokenizerBase(PreTrainedModel):
      """
      Class that additionally defines the behavior of any `audio_tokenizer` to be added.
      Characteristic for any of them:
          1. Encode raw audio into discrete audio codebooks (with x channels)
          2. Decode from discrete audio codebooks back to raw audio
      It is possible that they can decode in different ways given a different representation
      but they are forced to support 2. nonetheless, e.g. see `DAC`.

      Both `encode` and `decode` are declared abstract here and carry no implementation.
      """

      @abstractmethod
      def encode(self, input_values: torch.Tensor, *args, **kwargs):
          """
          Encode raw audio retrieved from a respective `FeatureExtractor` into discrete audio codebooks (with x channels)

          Args:
              input_values (`torch.Tensor`): The raw audio to encode.
              *args, **kwargs: Left to the concrete implementation.
          """

      @abstractmethod
      def decode(self, audio_codes: torch.Tensor, *args, **kwargs):
          """Decode from discrete audio codebooks back to raw audio

          Args:
              audio_codes (`torch.Tensor`): The discrete audio codes to decode.
              *args, **kwargs: Left to the concrete implementation.
          """