modeling_auto.py 97 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376
  1. # Copyright 2018 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Auto Model class."""
  15. import os
  16. from collections import OrderedDict
  17. from typing import TYPE_CHECKING
  18. from ...utils import logging
  19. from .auto_factory import (
  20. _BaseAutoBackboneClass,
  21. _BaseAutoModelClass,
  22. _LazyAutoMapping,
  23. auto_class_update,
  24. )
  25. from .configuration_auto import CONFIG_MAPPING_NAMES
  26. if TYPE_CHECKING:
  27. from ...generation import GenerationMixin
  28. from ...modeling_utils import PreTrainedModel
  29. # class for better type annotations
  30. class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
  31. pass
# Module-level logger, obtained from the package's `logging` utility and keyed by this module's name.
logger = logging.get_logger(__name__)
  33. MODEL_MAPPING_NAMES = OrderedDict(
  34. [
  35. # Base model mapping
  36. ("afmoe", "AfmoeModel"),
  37. ("aimv2", "Aimv2Model"),
  38. ("aimv2_vision_model", "Aimv2VisionModel"),
  39. ("albert", "AlbertModel"),
  40. ("align", "AlignModel"),
  41. ("altclip", "AltCLIPModel"),
  42. ("apertus", "ApertusModel"),
  43. ("arcee", "ArceeModel"),
  44. ("aria", "AriaModel"),
  45. ("aria_text", "AriaTextModel"),
  46. ("audio-spectrogram-transformer", "ASTModel"),
  47. ("audioflamingo3", "AudioFlamingo3ForConditionalGeneration"),
  48. ("audioflamingo3_encoder", "AudioFlamingo3Encoder"),
  49. ("autoformer", "AutoformerModel"),
  50. ("aya_vision", "AyaVisionModel"),
  51. ("bamba", "BambaModel"),
  52. ("bark", "BarkModel"),
  53. ("bart", "BartModel"),
  54. ("beit", "BeitModel"),
  55. ("bert", "BertModel"),
  56. ("bert-generation", "BertGenerationEncoder"),
  57. ("big_bird", "BigBirdModel"),
  58. ("bigbird_pegasus", "BigBirdPegasusModel"),
  59. ("biogpt", "BioGptModel"),
  60. ("bit", "BitModel"),
  61. ("bitnet", "BitNetModel"),
  62. ("blenderbot", "BlenderbotModel"),
  63. ("blenderbot-small", "BlenderbotSmallModel"),
  64. ("blip", "BlipModel"),
  65. ("blip-2", "Blip2Model"),
  66. ("blip_2_qformer", "Blip2QFormerModel"),
  67. ("bloom", "BloomModel"),
  68. ("blt", "BltModel"),
  69. ("bridgetower", "BridgeTowerModel"),
  70. ("bros", "BrosModel"),
  71. ("camembert", "CamembertModel"),
  72. ("canine", "CanineModel"),
  73. ("chameleon", "ChameleonModel"),
  74. ("chinese_clip", "ChineseCLIPModel"),
  75. ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
  76. ("clap", "ClapModel"),
  77. ("clip", "CLIPModel"),
  78. ("clip_text_model", "CLIPTextModel"),
  79. ("clip_vision_model", "CLIPVisionModel"),
  80. ("clipseg", "CLIPSegModel"),
  81. ("clvp", "ClvpModelForConditionalGeneration"),
  82. ("code_llama", "LlamaModel"),
  83. ("codegen", "CodeGenModel"),
  84. ("cohere", "CohereModel"),
  85. ("cohere2", "Cohere2Model"),
  86. ("cohere2_vision", "Cohere2VisionModel"),
  87. ("cohere_asr", "CohereAsrModel"),
  88. ("conditional_detr", "ConditionalDetrModel"),
  89. ("convbert", "ConvBertModel"),
  90. ("convnext", "ConvNextModel"),
  91. ("convnextv2", "ConvNextV2Model"),
  92. ("cpmant", "CpmAntModel"),
  93. ("csm", "CsmForConditionalGeneration"),
  94. ("ctrl", "CTRLModel"),
  95. ("cvt", "CvtModel"),
  96. ("cwm", "CwmModel"),
  97. ("d_fine", "DFineModel"),
  98. ("dab-detr", "DabDetrModel"),
  99. ("dac", "DacModel"),
  100. ("data2vec-audio", "Data2VecAudioModel"),
  101. ("data2vec-text", "Data2VecTextModel"),
  102. ("data2vec-vision", "Data2VecVisionModel"),
  103. ("dbrx", "DbrxModel"),
  104. ("deberta", "DebertaModel"),
  105. ("deberta-v2", "DebertaV2Model"),
  106. ("decision_transformer", "DecisionTransformerModel"),
  107. ("deepseek_v2", "DeepseekV2Model"),
  108. ("deepseek_v3", "DeepseekV3Model"),
  109. ("deepseek_vl", "DeepseekVLModel"),
  110. ("deepseek_vl_hybrid", "DeepseekVLHybridModel"),
  111. ("deformable_detr", "DeformableDetrModel"),
  112. ("deit", "DeiTModel"),
  113. ("depth_pro", "DepthProModel"),
  114. ("detr", "DetrModel"),
  115. ("dia", "DiaModel"),
  116. ("diffllama", "DiffLlamaModel"),
  117. ("dinat", "DinatModel"),
  118. ("dinov2", "Dinov2Model"),
  119. ("dinov2_with_registers", "Dinov2WithRegistersModel"),
  120. ("dinov3_convnext", "DINOv3ConvNextModel"),
  121. ("dinov3_vit", "DINOv3ViTModel"),
  122. ("distilbert", "DistilBertModel"),
  123. ("doge", "DogeModel"),
  124. ("donut-swin", "DonutSwinModel"),
  125. ("dots1", "Dots1Model"),
  126. ("dpr", "DPRQuestionEncoder"),
  127. ("dpt", "DPTModel"),
  128. ("edgetam", "EdgeTamModel"),
  129. ("edgetam_video", "EdgeTamVideoModel"),
  130. ("edgetam_vision_model", "EdgeTamVisionModel"),
  131. ("efficientloftr", "EfficientLoFTRModel"),
  132. ("efficientnet", "EfficientNetModel"),
  133. ("electra", "ElectraModel"),
  134. ("emu3", "Emu3Model"),
  135. ("encodec", "EncodecModel"),
  136. ("ernie", "ErnieModel"),
  137. ("ernie4_5", "Ernie4_5Model"),
  138. ("ernie4_5_moe", "Ernie4_5_MoeModel"),
  139. ("ernie4_5_vl_moe", "Ernie4_5_VLMoeModel"),
  140. ("esm", "EsmModel"),
  141. ("eurobert", "EuroBertModel"),
  142. ("evolla", "EvollaModel"),
  143. ("exaone4", "Exaone4Model"),
  144. ("exaone_moe", "ExaoneMoeModel"),
  145. ("falcon", "FalconModel"),
  146. ("falcon_h1", "FalconH1Model"),
  147. ("falcon_mamba", "FalconMambaModel"),
  148. ("fast_vlm", "FastVlmModel"),
  149. ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
  150. ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
  151. ("flaubert", "FlaubertModel"),
  152. ("flava", "FlavaModel"),
  153. ("flex_olmo", "FlexOlmoModel"),
  154. ("florence2", "Florence2Model"),
  155. ("fnet", "FNetModel"),
  156. ("focalnet", "FocalNetModel"),
  157. ("fsmt", "FSMTModel"),
  158. ("funnel", ("FunnelModel", "FunnelBaseModel")),
  159. ("fuyu", "FuyuModel"),
  160. ("gemma", "GemmaModel"),
  161. ("gemma2", "Gemma2Model"),
  162. ("gemma3", "Gemma3Model"),
  163. ("gemma3_text", "Gemma3TextModel"),
  164. ("gemma3n", "Gemma3nModel"),
  165. ("gemma3n_audio", "Gemma3nAudioEncoder"),
  166. ("gemma3n_text", "Gemma3nTextModel"),
  167. ("gemma3n_vision", "TimmWrapperModel"),
  168. ("gemma4", "Gemma4Model"),
  169. ("gemma4_audio", "Gemma4AudioModel"),
  170. ("gemma4_text", "Gemma4TextModel"),
  171. ("gemma4_vision", "Gemma4VisionModel"),
  172. ("git", "GitModel"),
  173. ("glm", "GlmModel"),
  174. ("glm4", "Glm4Model"),
  175. ("glm46v", "Glm46VModel"),
  176. ("glm4_moe", "Glm4MoeModel"),
  177. ("glm4_moe_lite", "Glm4MoeLiteModel"),
  178. ("glm4v", "Glm4vModel"),
  179. ("glm4v_moe", "Glm4vMoeModel"),
  180. ("glm4v_moe_text", "Glm4vMoeTextModel"),
  181. ("glm4v_moe_vision", "Glm4vMoeVisionModel"),
  182. ("glm4v_text", "Glm4vTextModel"),
  183. ("glm4v_vision", "Glm4vVisionModel"),
  184. ("glm_image", "GlmImageModel"),
  185. ("glm_image_text", "GlmImageTextModel"),
  186. ("glm_image_vision", "GlmImageVisionModel"),
  187. ("glm_image_vqmodel", "GlmImageVQVAE"),
  188. ("glm_moe_dsa", "GlmMoeDsaModel"),
  189. ("glm_ocr", "GlmOcrModel"),
  190. ("glm_ocr_text", "GlmOcrTextModel"),
  191. ("glm_ocr_vision", "GlmOcrVisionModel"),
  192. ("glmasr", "GlmAsrForConditionalGeneration"),
  193. ("glmasr_encoder", "GlmAsrEncoder"),
  194. ("glpn", "GLPNModel"),
  195. ("got_ocr2", "GotOcr2Model"),
  196. ("gpt-sw3", "GPT2Model"),
  197. ("gpt2", "GPT2Model"),
  198. ("gpt_bigcode", "GPTBigCodeModel"),
  199. ("gpt_neo", "GPTNeoModel"),
  200. ("gpt_neox", "GPTNeoXModel"),
  201. ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
  202. ("gpt_oss", "GptOssModel"),
  203. ("gptj", "GPTJModel"),
  204. ("granite", "GraniteModel"),
  205. ("granitemoe", "GraniteMoeModel"),
  206. ("granitemoehybrid", "GraniteMoeHybridModel"),
  207. ("granitemoeshared", "GraniteMoeSharedModel"),
  208. ("grounding-dino", "GroundingDinoModel"),
  209. ("groupvit", "GroupViTModel"),
  210. ("helium", "HeliumModel"),
  211. ("hgnet_v2", "HGNetV2Backbone"),
  212. ("hiera", "HieraModel"),
  213. ("higgs_audio_v2", "HiggsAudioV2ForConditionalGeneration"),
  214. ("higgs_audio_v2_tokenizer", "HiggsAudioV2TokenizerModel"),
  215. ("hubert", "HubertModel"),
  216. ("hunyuan_v1_dense", "HunYuanDenseV1Model"),
  217. ("hunyuan_v1_moe", "HunYuanMoEV1Model"),
  218. ("ibert", "IBertModel"),
  219. ("idefics", "IdeficsModel"),
  220. ("idefics2", "Idefics2Model"),
  221. ("idefics3", "Idefics3Model"),
  222. ("idefics3_vision", "Idefics3VisionTransformer"),
  223. ("ijepa", "IJepaModel"),
  224. ("imagegpt", "ImageGPTModel"),
  225. ("informer", "InformerModel"),
  226. ("instructblip", "InstructBlipModel"),
  227. ("instructblipvideo", "InstructBlipVideoModel"),
  228. ("internvl", "InternVLModel"),
  229. ("internvl_vision", "InternVLVisionModel"),
  230. ("jais2", "Jais2Model"),
  231. ("jamba", "JambaModel"),
  232. ("janus", "JanusModel"),
  233. ("jetmoe", "JetMoeModel"),
  234. ("jina_embeddings_v3", "JinaEmbeddingsV3Model"),
  235. ("kosmos-2", "Kosmos2Model"),
  236. ("kosmos-2.5", "Kosmos2_5Model"),
  237. ("kyutai_speech_to_text", "KyutaiSpeechToTextModel"),
  238. ("lasr_ctc", "LasrForCTC"),
  239. ("lasr_encoder", "LasrEncoder"),
  240. ("layoutlm", "LayoutLMModel"),
  241. ("layoutlmv2", "LayoutLMv2Model"),
  242. ("layoutlmv3", "LayoutLMv3Model"),
  243. ("led", "LEDModel"),
  244. ("levit", "LevitModel"),
  245. ("lfm2", "Lfm2Model"),
  246. ("lfm2_moe", "Lfm2MoeModel"),
  247. ("lfm2_vl", "Lfm2VlModel"),
  248. ("lightglue", "LightGlueForKeypointMatching"),
  249. ("lighton_ocr", "LightOnOcrModel"),
  250. ("lilt", "LiltModel"),
  251. ("llama", "LlamaModel"),
  252. ("llama4", "Llama4ForConditionalGeneration"),
  253. ("llama4_text", "Llama4TextModel"),
  254. ("llava", "LlavaModel"),
  255. ("llava_next", "LlavaNextModel"),
  256. ("llava_next_video", "LlavaNextVideoModel"),
  257. ("llava_onevision", "LlavaOnevisionModel"),
  258. ("longcat_flash", "LongcatFlashModel"),
  259. ("longformer", "LongformerModel"),
  260. ("longt5", "LongT5Model"),
  261. ("luke", "LukeModel"),
  262. ("lw_detr", "LwDetrModel"),
  263. ("lxmert", "LxmertModel"),
  264. ("m2m_100", "M2M100Model"),
  265. ("mamba", "MambaModel"),
  266. ("mamba2", "Mamba2Model"),
  267. ("marian", "MarianModel"),
  268. ("markuplm", "MarkupLMModel"),
  269. ("mask2former", "Mask2FormerModel"),
  270. ("maskformer", "MaskFormerModel"),
  271. ("maskformer-swin", "MaskFormerSwinModel"),
  272. ("mbart", "MBartModel"),
  273. ("megatron-bert", "MegatronBertModel"),
  274. ("metaclip_2", "MetaClip2Model"),
  275. ("mgp-str", "MgpstrForSceneTextRecognition"),
  276. ("mimi", "MimiModel"),
  277. ("minimax", "MiniMaxModel"),
  278. ("minimax_m2", "MiniMaxM2Model"),
  279. ("ministral", "MinistralModel"),
  280. ("ministral3", "Ministral3Model"),
  281. ("mistral", "MistralModel"),
  282. ("mistral3", "Mistral3Model"),
  283. ("mistral4", "Mistral4Model"),
  284. ("mixtral", "MixtralModel"),
  285. ("mlcd", "MLCDVisionModel"), # Keep this to make some original hub repositories (from `DeepGlint-AI`) works
  286. ("mlcd_vision_model", "MLCDVisionModel"),
  287. ("mllama", "MllamaModel"),
  288. ("mm-grounding-dino", "MMGroundingDinoModel"),
  289. ("mobilebert", "MobileBertModel"),
  290. ("mobilenet_v1", "MobileNetV1Model"),
  291. ("mobilenet_v2", "MobileNetV2Model"),
  292. ("mobilevit", "MobileViTModel"),
  293. ("mobilevitv2", "MobileViTV2Model"),
  294. ("modernbert", "ModernBertModel"),
  295. ("modernbert-decoder", "ModernBertDecoderModel"),
  296. ("modernvbert", "ModernVBertModel"),
  297. ("moonshine", "MoonshineModel"),
  298. ("moonshine_streaming", "MoonshineStreamingModel"),
  299. ("moshi", "MoshiModel"),
  300. ("mpnet", "MPNetModel"),
  301. ("mpt", "MptModel"),
  302. ("mra", "MraModel"),
  303. ("mt5", "MT5Model"),
  304. ("musicflamingo", "MusicFlamingoForConditionalGeneration"),
  305. ("musicflamingo_encoder", "AudioFlamingo3Encoder"),
  306. ("musicgen", "MusicgenModel"),
  307. ("musicgen_melody", "MusicgenMelodyModel"),
  308. ("mvp", "MvpModel"),
  309. ("nanochat", "NanoChatModel"),
  310. ("nemotron", "NemotronModel"),
  311. ("nemotron_h", "NemotronHModel"),
  312. ("nllb-moe", "NllbMoeModel"),
  313. ("nomic_bert", "NomicBertModel"),
  314. ("nystromformer", "NystromformerModel"),
  315. ("olmo", "OlmoModel"),
  316. ("olmo2", "Olmo2Model"),
  317. ("olmo3", "Olmo3Model"),
  318. ("olmo_hybrid", "OlmoHybridModel"),
  319. ("olmoe", "OlmoeModel"),
  320. ("omdet-turbo", "OmDetTurboForObjectDetection"),
  321. ("oneformer", "OneFormerModel"),
  322. ("openai-gpt", "OpenAIGPTModel"),
  323. ("opt", "OPTModel"),
  324. ("ovis2", "Ovis2Model"),
  325. ("owlv2", "Owlv2Model"),
  326. ("owlvit", "OwlViTModel"),
  327. ("paligemma", "PaliGemmaModel"),
  328. ("parakeet_ctc", "ParakeetForCTC"),
  329. ("parakeet_encoder", "ParakeetEncoder"),
  330. ("patchtsmixer", "PatchTSMixerModel"),
  331. ("patchtst", "PatchTSTModel"),
  332. ("pe_audio", "PeAudioModel"),
  333. ("pe_audio_encoder", "PeAudioEncoder"),
  334. ("pe_audio_video", "PeAudioVideoModel"),
  335. ("pe_audio_video_encoder", "PeAudioVideoEncoder"),
  336. ("pe_video", "PeVideoModel"),
  337. ("pe_video_encoder", "PeVideoEncoder"),
  338. ("pegasus", "PegasusModel"),
  339. ("pegasus_x", "PegasusXModel"),
  340. ("perceiver", "PerceiverModel"),
  341. ("perception_lm", "PerceptionLMModel"),
  342. ("persimmon", "PersimmonModel"),
  343. ("phi", "PhiModel"),
  344. ("phi3", "Phi3Model"),
  345. ("phi4_multimodal", "Phi4MultimodalModel"),
  346. ("phimoe", "PhimoeModel"),
  347. ("pi0", "PI0Model"),
  348. ("pixio", "PixioModel"),
  349. ("pixtral", "PixtralVisionModel"),
  350. ("plbart", "PLBartModel"),
  351. ("poolformer", "PoolFormerModel"),
  352. ("pp_doclayout_v3", "PPDocLayoutV3Model"),
  353. ("pp_ocrv5_mobile_rec", "PPOCRV5MobileRecModel"),
  354. ("pp_ocrv5_server_rec", "PPOCRV5ServerRecModel"),
  355. ("prophetnet", "ProphetNetModel"),
  356. ("pvt", "PvtModel"),
  357. ("pvt_v2", "PvtV2Model"),
  358. ("qwen2", "Qwen2Model"),
  359. ("qwen2_5_vl", "Qwen2_5_VLModel"),
  360. ("qwen2_5_vl_text", "Qwen2_5_VLTextModel"),
  361. ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
  362. ("qwen2_moe", "Qwen2MoeModel"),
  363. ("qwen2_vl", "Qwen2VLModel"),
  364. ("qwen2_vl_text", "Qwen2VLTextModel"),
  365. ("qwen3", "Qwen3Model"),
  366. ("qwen3_5", "Qwen3_5Model"),
  367. ("qwen3_5_moe", "Qwen3_5MoeModel"),
  368. ("qwen3_5_moe_text", "Qwen3_5MoeTextModel"),
  369. ("qwen3_5_text", "Qwen3_5TextModel"),
  370. ("qwen3_moe", "Qwen3MoeModel"),
  371. ("qwen3_next", "Qwen3NextModel"),
  372. ("qwen3_vl", "Qwen3VLModel"),
  373. ("qwen3_vl_moe", "Qwen3VLMoeModel"),
  374. ("qwen3_vl_moe_text", "Qwen3VLMoeTextModel"),
  375. ("qwen3_vl_text", "Qwen3VLTextModel"),
  376. ("recurrent_gemma", "RecurrentGemmaModel"),
  377. ("reformer", "ReformerModel"),
  378. ("regnet", "RegNetModel"),
  379. ("rembert", "RemBertModel"),
  380. ("resnet", "ResNetModel"),
  381. ("roberta", "RobertaModel"),
  382. ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
  383. ("roc_bert", "RoCBertModel"),
  384. ("roformer", "RoFormerModel"),
  385. ("rt_detr", "RTDetrModel"),
  386. ("rt_detr_v2", "RTDetrV2Model"),
  387. ("rwkv", "RwkvModel"),
  388. ("sam", "SamModel"),
  389. ("sam2", "Sam2Model"),
  390. ("sam2_hiera_det_model", "Sam2HieraDetModel"),
  391. ("sam2_video", "Sam2VideoModel"),
  392. ("sam2_vision_model", "Sam2VisionModel"),
  393. ("sam3", "Sam3Model"),
  394. ("sam3_tracker", "Sam3TrackerModel"),
  395. ("sam3_tracker", "Sam3TrackerModel"),
  396. ("sam3_tracker_video", "Sam3TrackerVideoModel"),
  397. ("sam3_video", "Sam3VideoModel"),
  398. ("sam3_vision_model", "Sam3VisionModel"),
  399. ("sam3_vit_model", "Sam3ViTModel"),
  400. ("sam_hq", "SamHQModel"),
  401. ("sam_hq_vision_model", "SamHQVisionModel"),
  402. ("sam_vision_model", "SamVisionModel"),
  403. ("seamless_m4t", "SeamlessM4TModel"),
  404. ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
  405. ("seed_oss", "SeedOssModel"),
  406. ("segformer", "SegformerModel"),
  407. ("seggpt", "SegGptModel"),
  408. ("sew", "SEWModel"),
  409. ("sew-d", "SEWDModel"),
  410. ("siglip", "SiglipModel"),
  411. ("siglip2", "Siglip2Model"),
  412. ("siglip2_vision_model", "Siglip2VisionModel"),
  413. ("siglip_vision_model", "SiglipVisionModel"),
  414. ("smollm3", "SmolLM3Model"),
  415. ("smolvlm", "SmolVLMModel"),
  416. ("smolvlm_vision", "SmolVLMVisionTransformer"),
  417. ("solar_open", "SolarOpenModel"),
  418. ("speech_to_text", "Speech2TextModel"),
  419. ("speecht5", "SpeechT5Model"),
  420. ("splinter", "SplinterModel"),
  421. ("squeezebert", "SqueezeBertModel"),
  422. ("stablelm", "StableLmModel"),
  423. ("starcoder2", "Starcoder2Model"),
  424. ("swiftformer", "SwiftFormerModel"),
  425. ("swin", "SwinModel"),
  426. ("swin2sr", "Swin2SRModel"),
  427. ("swinv2", "Swinv2Model"),
  428. ("switch_transformers", "SwitchTransformersModel"),
  429. ("t5", "T5Model"),
  430. ("t5gemma", "T5GemmaModel"),
  431. ("t5gemma2", "T5Gemma2Model"),
  432. ("t5gemma2_encoder", "T5Gemma2Encoder"),
  433. ("table-transformer", "TableTransformerModel"),
  434. ("tapas", "TapasModel"),
  435. ("textnet", "TextNetModel"),
  436. ("time_series_transformer", "TimeSeriesTransformerModel"),
  437. ("timesfm", "TimesFmModel"),
  438. ("timesfm2_5", "TimesFm2_5Model"),
  439. ("timesformer", "TimesformerModel"),
  440. ("timm_backbone", "TimmBackbone"),
  441. ("timm_wrapper", "TimmWrapperModel"),
  442. ("tvp", "TvpModel"),
  443. ("udop", "UdopModel"),
  444. ("umt5", "UMT5Model"),
  445. ("unispeech", "UniSpeechModel"),
  446. ("unispeech-sat", "UniSpeechSatModel"),
  447. ("univnet", "UnivNetModel"),
  448. ("uvdoc", "UVDocModel"),
  449. ("vaultgemma", "VaultGemmaModel"),
  450. ("vibevoice_acoustic_tokenizer", "VibeVoiceAcousticTokenizerModel"),
  451. ("vibevoice_acoustic_tokenizer_decoder", "VibeVoiceAcousticTokenizerDecoderModel"),
  452. ("vibevoice_acoustic_tokenizer_encoder", "VibeVoiceAcousticTokenizerEncoderModel"),
  453. ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"),
  454. ("video_llama_3", "VideoLlama3Model"),
  455. ("video_llama_3_vision", "VideoLlama3VisionModel"),
  456. ("video_llava", "VideoLlavaModel"),
  457. ("videomae", "VideoMAEModel"),
  458. ("vilt", "ViltModel"),
  459. ("vipllava", "VipLlavaModel"),
  460. ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
  461. ("visual_bert", "VisualBertModel"),
  462. ("vit", "ViTModel"),
  463. ("vit_mae", "ViTMAEModel"),
  464. ("vit_msn", "ViTMSNModel"),
  465. ("vitdet", "VitDetModel"),
  466. ("vits", "VitsModel"),
  467. ("vivit", "VivitModel"),
  468. ("vjepa2", "VJEPA2Model"),
  469. ("voxtral", "VoxtralForConditionalGeneration"),
  470. ("voxtral_encoder", "VoxtralEncoder"),
  471. ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"),
  472. ("voxtral_realtime_encoder", "VoxtralRealtimeEncoder"),
  473. ("voxtral_realtime_text", "VoxtralRealtimeTextModel"),
  474. ("wav2vec2", "Wav2Vec2Model"),
  475. ("wav2vec2-bert", "Wav2Vec2BertModel"),
  476. ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
  477. ("wavlm", "WavLMModel"),
  478. ("whisper", "WhisperModel"),
  479. ("xclip", "XCLIPModel"),
  480. ("xcodec", "XcodecModel"),
  481. ("xglm", "XGLMModel"),
  482. ("xlm", "XLMModel"),
  483. ("xlm-roberta", "XLMRobertaModel"),
  484. ("xlm-roberta-xl", "XLMRobertaXLModel"),
  485. ("xlnet", "XLNetModel"),
  486. ("xlstm", "xLSTMModel"),
  487. ("xmod", "XmodModel"),
  488. ("yolos", "YolosModel"),
  489. ("yoso", "YosoModel"),
  490. ("youtu", "YoutuModel"),
  491. ("zamba", "ZambaModel"),
  492. ("zamba2", "Zamba2Model"),
  493. ]
  494. )
  495. MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
  496. [
  497. # Model for pre-training mapping
  498. ("albert", "AlbertForPreTraining"),
  499. ("audioflamingo3", "AudioFlamingo3ForConditionalGeneration"),
  500. ("bart", "BartForConditionalGeneration"),
  501. ("bert", "BertForPreTraining"),
  502. ("big_bird", "BigBirdForPreTraining"),
  503. ("bloom", "BloomForCausalLM"),
  504. ("camembert", "CamembertForMaskedLM"),
  505. ("colmodernvbert", "ColModernVBertForRetrieval"),
  506. ("colpali", "ColPaliForRetrieval"),
  507. ("colqwen2", "ColQwen2ForRetrieval"),
  508. ("ctrl", "CTRLLMHeadModel"),
  509. ("data2vec-text", "Data2VecTextForMaskedLM"),
  510. ("deberta", "DebertaForMaskedLM"),
  511. ("deberta-v2", "DebertaV2ForMaskedLM"),
  512. ("distilbert", "DistilBertForMaskedLM"),
  513. ("electra", "ElectraForPreTraining"),
  514. ("ernie", "ErnieForPreTraining"),
  515. ("evolla", "EvollaForProteinText2Text"),
  516. ("exaone4", "Exaone4ForCausalLM"),
  517. ("exaone_moe", "ExaoneMoeForCausalLM"),
  518. ("falcon_mamba", "FalconMambaForCausalLM"),
  519. ("flaubert", "FlaubertWithLMHeadModel"),
  520. ("flava", "FlavaForPreTraining"),
  521. ("florence2", "Florence2ForConditionalGeneration"),
  522. ("fnet", "FNetForPreTraining"),
  523. ("fsmt", "FSMTForConditionalGeneration"),
  524. ("funnel", "FunnelForPreTraining"),
  525. ("gemma3", "Gemma3ForConditionalGeneration"),
  526. ("gemma4", "Gemma4ForConditionalGeneration"),
  527. ("glmasr", "GlmAsrForConditionalGeneration"),
  528. ("gpt-sw3", "GPT2LMHeadModel"),
  529. ("gpt2", "GPT2LMHeadModel"),
  530. ("gpt_bigcode", "GPTBigCodeForCausalLM"),
  531. ("hiera", "HieraForPreTraining"),
  532. ("ibert", "IBertForMaskedLM"),
  533. ("idefics", "IdeficsForVisionText2Text"),
  534. ("idefics2", "Idefics2ForConditionalGeneration"),
  535. ("idefics3", "Idefics3ForConditionalGeneration"),
  536. ("janus", "JanusForConditionalGeneration"),
  537. ("layoutlm", "LayoutLMForMaskedLM"),
  538. ("llava", "LlavaForConditionalGeneration"),
  539. ("llava_next", "LlavaNextForConditionalGeneration"),
  540. ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
  541. ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
  542. ("longformer", "LongformerForMaskedLM"),
  543. ("luke", "LukeForMaskedLM"),
  544. ("lxmert", "LxmertForPreTraining"),
  545. ("mamba", "MambaForCausalLM"),
  546. ("mamba2", "Mamba2ForCausalLM"),
  547. ("megatron-bert", "MegatronBertForPreTraining"),
  548. ("mistral3", "Mistral3ForConditionalGeneration"),
  549. ("mistral4", "Mistral4ForCausalLM"),
  550. ("mllama", "MllamaForConditionalGeneration"),
  551. ("mobilebert", "MobileBertForPreTraining"),
  552. ("mpnet", "MPNetForMaskedLM"),
  553. ("mpt", "MptForCausalLM"),
  554. ("mra", "MraForMaskedLM"),
  555. ("musicflamingo", "MusicFlamingoForConditionalGeneration"),
  556. ("mvp", "MvpForConditionalGeneration"),
  557. ("nanochat", "NanoChatForCausalLM"),
  558. ("nllb-moe", "NllbMoeForConditionalGeneration"),
  559. ("openai-gpt", "OpenAIGPTLMHeadModel"),
  560. ("paligemma", "PaliGemmaForConditionalGeneration"),
  561. ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
  562. ("roberta", "RobertaForMaskedLM"),
  563. ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
  564. ("roc_bert", "RoCBertForPreTraining"),
  565. ("rwkv", "RwkvForCausalLM"),
  566. ("splinter", "SplinterForPreTraining"),
  567. ("squeezebert", "SqueezeBertForMaskedLM"),
  568. ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
  569. ("t5", "T5ForConditionalGeneration"),
  570. ("t5gemma", "T5GemmaForConditionalGeneration"),
  571. ("t5gemma2", "T5Gemma2ForConditionalGeneration"),
  572. ("tapas", "TapasForMaskedLM"),
  573. ("unispeech", "UniSpeechForPreTraining"),
  574. ("unispeech-sat", "UniSpeechSatForPreTraining"),
  575. ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"),
  576. ("video_llava", "VideoLlavaForConditionalGeneration"),
  577. ("videomae", "VideoMAEForPreTraining"),
  578. ("vipllava", "VipLlavaForConditionalGeneration"),
  579. ("visual_bert", "VisualBertForPreTraining"),
  580. ("vit_mae", "ViTMAEForPreTraining"),
  581. ("voxtral", "VoxtralForConditionalGeneration"),
  582. ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"),
  583. ("wav2vec2", "Wav2Vec2ForPreTraining"),
  584. ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
  585. ("xlm", "XLMWithLMHeadModel"),
  586. ("xlm-roberta", "XLMRobertaForMaskedLM"),
  587. ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
  588. ("xlnet", "XLNetLMHeadModel"),
  589. ("xlstm", "xLSTMForCausalLM"),
  590. ("xmod", "XmodForMaskedLM"),
  591. ]
  592. )
  593. MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
  594. [
  595. # Model for Causal LM mapping
  596. ("afmoe", "AfmoeForCausalLM"),
  597. ("apertus", "ApertusForCausalLM"),
  598. ("arcee", "ArceeForCausalLM"),
  599. ("aria_text", "AriaTextForCausalLM"),
  600. ("bamba", "BambaForCausalLM"),
  601. ("bart", "BartForCausalLM"),
  602. ("bert", "BertLMHeadModel"),
  603. ("bert-generation", "BertGenerationDecoder"),
  604. ("big_bird", "BigBirdForCausalLM"),
  605. ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
  606. ("biogpt", "BioGptForCausalLM"),
  607. ("bitnet", "BitNetForCausalLM"),
  608. ("blenderbot", "BlenderbotForCausalLM"),
  609. ("blenderbot-small", "BlenderbotSmallForCausalLM"),
  610. ("bloom", "BloomForCausalLM"),
  611. ("blt", "BltForCausalLM"),
  612. ("camembert", "CamembertForCausalLM"),
  613. ("code_llama", "LlamaForCausalLM"),
  614. ("codegen", "CodeGenForCausalLM"),
  615. ("cohere", "CohereForCausalLM"),
  616. ("cohere2", "Cohere2ForCausalLM"),
  617. ("cpmant", "CpmAntForCausalLM"),
  618. ("ctrl", "CTRLLMHeadModel"),
  619. ("cwm", "CwmForCausalLM"),
  620. ("data2vec-text", "Data2VecTextForCausalLM"),
  621. ("dbrx", "DbrxForCausalLM"),
  622. ("deepseek_v2", "DeepseekV2ForCausalLM"),
  623. ("deepseek_v3", "DeepseekV3ForCausalLM"),
  624. ("diffllama", "DiffLlamaForCausalLM"),
  625. ("doge", "DogeForCausalLM"),
  626. ("dots1", "Dots1ForCausalLM"),
  627. ("electra", "ElectraForCausalLM"),
  628. ("emu3", "Emu3ForCausalLM"),
  629. ("ernie", "ErnieForCausalLM"),
  630. ("ernie4_5", "Ernie4_5ForCausalLM"),
  631. ("ernie4_5_moe", "Ernie4_5_MoeForCausalLM"),
  632. ("exaone4", "Exaone4ForCausalLM"),
  633. ("exaone_moe", "ExaoneMoeForCausalLM"),
  634. ("falcon", "FalconForCausalLM"),
  635. ("falcon_h1", "FalconH1ForCausalLM"),
  636. ("falcon_mamba", "FalconMambaForCausalLM"),
  637. ("flex_olmo", "FlexOlmoForCausalLM"),
  638. ("fuyu", "FuyuForCausalLM"),
  639. ("gemma", "GemmaForCausalLM"),
  640. ("gemma2", "Gemma2ForCausalLM"),
  641. ("gemma3", "Gemma3ForConditionalGeneration"),
  642. ("gemma3_text", "Gemma3ForCausalLM"),
  643. ("gemma3n", "Gemma3nForConditionalGeneration"),
  644. ("gemma3n_text", "Gemma3nForCausalLM"),
  645. ("gemma4", "Gemma4ForConditionalGeneration"),
  646. ("gemma4_text", "Gemma4ForCausalLM"),
  647. ("git", "GitForCausalLM"),
  648. ("glm", "GlmForCausalLM"),
  649. ("glm4", "Glm4ForCausalLM"),
  650. ("glm4_moe", "Glm4MoeForCausalLM"),
  651. ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
  652. ("glm_moe_dsa", "GlmMoeDsaForCausalLM"),
  653. ("got_ocr2", "GotOcr2ForConditionalGeneration"),
  654. ("gpt-sw3", "GPT2LMHeadModel"),
  655. ("gpt2", "GPT2LMHeadModel"),
  656. ("gpt_bigcode", "GPTBigCodeForCausalLM"),
  657. ("gpt_neo", "GPTNeoForCausalLM"),
  658. ("gpt_neox", "GPTNeoXForCausalLM"),
  659. ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
  660. ("gpt_oss", "GptOssForCausalLM"),
  661. ("gptj", "GPTJForCausalLM"),
  662. ("granite", "GraniteForCausalLM"),
  663. ("granitemoe", "GraniteMoeForCausalLM"),
  664. ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),
  665. ("granitemoeshared", "GraniteMoeSharedForCausalLM"),
  666. ("helium", "HeliumForCausalLM"),
  667. ("hunyuan_v1_dense", "HunYuanDenseV1ForCausalLM"),
  668. ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
  669. ("jais2", "Jais2ForCausalLM"),
  670. ("jamba", "JambaForCausalLM"),
  671. ("jetmoe", "JetMoeForCausalLM"),
  672. ("lfm2", "Lfm2ForCausalLM"),
  673. ("lfm2_moe", "Lfm2MoeForCausalLM"),
  674. ("llama", "LlamaForCausalLM"),
  675. ("llama4", "Llama4ForCausalLM"),
  676. ("llama4_text", "Llama4ForCausalLM"),
  677. ("longcat_flash", "LongcatFlashForCausalLM"),
  678. ("mamba", "MambaForCausalLM"),
  679. ("mamba2", "Mamba2ForCausalLM"),
  680. ("marian", "MarianForCausalLM"),
  681. ("mbart", "MBartForCausalLM"),
  682. ("megatron-bert", "MegatronBertForCausalLM"),
  683. ("minimax", "MiniMaxForCausalLM"),
  684. ("minimax_m2", "MiniMaxM2ForCausalLM"),
  685. ("ministral", "MinistralForCausalLM"),
  686. ("ministral3", "Ministral3ForCausalLM"),
  687. ("mistral", "MistralForCausalLM"),
  688. ("mixtral", "MixtralForCausalLM"),
  689. ("mllama", "MllamaForCausalLM"),
  690. ("modernbert-decoder", "ModernBertDecoderForCausalLM"),
  691. ("moshi", "MoshiForCausalLM"),
  692. ("mpt", "MptForCausalLM"),
  693. ("musicgen", "MusicgenForCausalLM"),
  694. ("musicgen_melody", "MusicgenMelodyForCausalLM"),
  695. ("mvp", "MvpForCausalLM"),
  696. ("nanochat", "NanoChatForCausalLM"),
  697. ("nemotron", "NemotronForCausalLM"),
  698. ("nemotron_h", "NemotronHForCausalLM"),
  699. ("olmo", "OlmoForCausalLM"),
  700. ("olmo2", "Olmo2ForCausalLM"),
  701. ("olmo3", "Olmo3ForCausalLM"),
  702. ("olmo_hybrid", "OlmoHybridForCausalLM"),
  703. ("olmoe", "OlmoeForCausalLM"),
  704. ("openai-gpt", "OpenAIGPTLMHeadModel"),
  705. ("opt", "OPTForCausalLM"),
  706. ("pegasus", "PegasusForCausalLM"),
  707. ("persimmon", "PersimmonForCausalLM"),
  708. ("phi", "PhiForCausalLM"),
  709. ("phi3", "Phi3ForCausalLM"),
  710. ("phi4_multimodal", "Phi4MultimodalForCausalLM"),
  711. ("phimoe", "PhimoeForCausalLM"),
  712. ("plbart", "PLBartForCausalLM"),
  713. ("prophetnet", "ProphetNetForCausalLM"),
  714. ("qwen2", "Qwen2ForCausalLM"),
  715. ("qwen2_moe", "Qwen2MoeForCausalLM"),
  716. ("qwen3", "Qwen3ForCausalLM"),
  717. ("qwen3_5", "Qwen3_5ForCausalLM"), # VLM compatibility
  718. ("qwen3_5_moe", "Qwen3_5MoeForCausalLM"), # VLM compatibility
  719. ("qwen3_5_moe_text", "Qwen3_5MoeForCausalLM"),
  720. ("qwen3_5_text", "Qwen3_5ForCausalLM"),
  721. ("qwen3_moe", "Qwen3MoeForCausalLM"),
  722. ("qwen3_next", "Qwen3NextForCausalLM"),
  723. ("recurrent_gemma", "RecurrentGemmaForCausalLM"),
  724. ("reformer", "ReformerModelWithLMHead"),
  725. ("rembert", "RemBertForCausalLM"),
  726. ("roberta", "RobertaForCausalLM"),
  727. ("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"),
  728. ("roc_bert", "RoCBertForCausalLM"),
  729. ("roformer", "RoFormerForCausalLM"),
  730. ("rwkv", "RwkvForCausalLM"),
  731. ("seed_oss", "SeedOssForCausalLM"),
  732. ("smollm3", "SmolLM3ForCausalLM"),
  733. ("solar_open", "SolarOpenForCausalLM"),
  734. ("stablelm", "StableLmForCausalLM"),
  735. ("starcoder2", "Starcoder2ForCausalLM"),
  736. ("trocr", "TrOCRForCausalLM"),
  737. ("vaultgemma", "VaultGemmaForCausalLM"),
  738. ("whisper", "WhisperForCausalLM"),
  739. ("xglm", "XGLMForCausalLM"),
  740. ("xlm", "XLMWithLMHeadModel"),
  741. ("xlm-roberta", "XLMRobertaForCausalLM"),
  742. ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
  743. ("xlnet", "XLNetLMHeadModel"),
  744. ("xlstm", "xLSTMForCausalLM"),
  745. ("xmod", "XmodForCausalLM"),
  746. ("youtu", "YoutuForCausalLM"),
  747. ("zamba", "ZambaForCausalLM"),
  748. ("zamba2", "Zamba2ForCausalLM"),
  749. ]
  750. )
  751. MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
  752. [
  753. # Model for Image mapping
  754. ("aimv2_vision_model", "Aimv2VisionModel"),
  755. ("beit", "BeitModel"),
  756. ("bit", "BitModel"),
  757. ("cohere2_vision", "Cohere2VisionModel"),
  758. ("conditional_detr", "ConditionalDetrModel"),
  759. ("convnext", "ConvNextModel"),
  760. ("convnextv2", "ConvNextV2Model"),
  761. ("dab-detr", "DabDetrModel"),
  762. ("data2vec-vision", "Data2VecVisionModel"),
  763. ("deformable_detr", "DeformableDetrModel"),
  764. ("deit", "DeiTModel"),
  765. ("depth_pro", "DepthProModel"),
  766. ("detr", "DetrModel"),
  767. ("dinat", "DinatModel"),
  768. ("dinov2", "Dinov2Model"),
  769. ("dinov2_with_registers", "Dinov2WithRegistersModel"),
  770. ("dinov3_convnext", "DINOv3ConvNextModel"),
  771. ("dinov3_vit", "DINOv3ViTModel"),
  772. ("dpt", "DPTModel"),
  773. ("efficientnet", "EfficientNetModel"),
  774. ("focalnet", "FocalNetModel"),
  775. ("glpn", "GLPNModel"),
  776. ("hiera", "HieraModel"),
  777. ("ijepa", "IJepaModel"),
  778. ("imagegpt", "ImageGPTModel"),
  779. ("levit", "LevitModel"),
  780. ("llama4", "Llama4VisionModel"),
  781. ("mlcd", "MLCDVisionModel"), # Keep this to make some original hub repositories (from `DeepGlint-AI`) works
  782. ("mlcd_vision_model", "MLCDVisionModel"),
  783. ("mllama", "MllamaVisionModel"),
  784. ("mobilenet_v1", "MobileNetV1Model"),
  785. ("mobilenet_v2", "MobileNetV2Model"),
  786. ("mobilevit", "MobileViTModel"),
  787. ("mobilevitv2", "MobileViTV2Model"),
  788. ("pixio", "PixioModel"),
  789. ("poolformer", "PoolFormerModel"),
  790. ("pvt", "PvtModel"),
  791. ("regnet", "RegNetModel"),
  792. ("resnet", "ResNetModel"),
  793. ("segformer", "SegformerModel"),
  794. ("siglip_vision_model", "SiglipVisionModel"),
  795. ("swiftformer", "SwiftFormerModel"),
  796. ("swin", "SwinModel"),
  797. ("swin2sr", "Swin2SRModel"),
  798. ("swinv2", "Swinv2Model"),
  799. ("table-transformer", "TableTransformerModel"),
  800. ("timesformer", "TimesformerModel"),
  801. ("timm_backbone", "TimmBackbone"),
  802. ("timm_wrapper", "TimmWrapperModel"),
  803. ("videomae", "VideoMAEModel"),
  804. ("vit", "ViTModel"),
  805. ("vit_mae", "ViTMAEModel"),
  806. ("vit_msn", "ViTMSNModel"),
  807. ("vitdet", "VitDetModel"),
  808. ("vivit", "VivitModel"),
  809. ("yolos", "YolosModel"),
  810. ]
  811. )
  812. MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
  813. [
  814. ("deit", "DeiTForMaskedImageModeling"),
  815. ("focalnet", "FocalNetForMaskedImageModeling"),
  816. ("swin", "SwinForMaskedImageModeling"),
  817. ("swinv2", "Swinv2ForMaskedImageModeling"),
  818. ("vit", "ViTForMaskedImageModeling"),
  819. ]
  820. )
  821. MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
  822. # Model for Causal Image Modeling mapping
  823. [
  824. ("imagegpt", "ImageGPTForCausalImageModeling"),
  825. ]
  826. )
  827. MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  828. [
  829. # Model for Image Classification mapping
  830. ("beit", "BeitForImageClassification"),
  831. ("bit", "BitForImageClassification"),
  832. ("clip", "CLIPForImageClassification"),
  833. ("convnext", "ConvNextForImageClassification"),
  834. ("convnextv2", "ConvNextV2ForImageClassification"),
  835. ("cvt", "CvtForImageClassification"),
  836. ("data2vec-vision", "Data2VecVisionForImageClassification"),
  837. (
  838. "deit",
  839. ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher"),
  840. ),
  841. ("dinat", "DinatForImageClassification"),
  842. ("dinov2", "Dinov2ForImageClassification"),
  843. ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
  844. ("donut-swin", "DonutSwinForImageClassification"),
  845. ("efficientnet", "EfficientNetForImageClassification"),
  846. ("focalnet", "FocalNetForImageClassification"),
  847. ("hgnet_v2", "HGNetV2ForImageClassification"),
  848. ("hiera", "HieraForImageClassification"),
  849. ("ijepa", "IJepaForImageClassification"),
  850. ("imagegpt", "ImageGPTForImageClassification"),
  851. (
  852. "levit",
  853. ("LevitForImageClassification", "LevitForImageClassificationWithTeacher"),
  854. ),
  855. ("metaclip_2", "MetaClip2ForImageClassification"),
  856. ("mobilenet_v1", "MobileNetV1ForImageClassification"),
  857. ("mobilenet_v2", "MobileNetV2ForImageClassification"),
  858. ("mobilevit", "MobileViTForImageClassification"),
  859. ("mobilevitv2", "MobileViTV2ForImageClassification"),
  860. (
  861. "perceiver",
  862. (
  863. "PerceiverForImageClassificationLearned",
  864. "PerceiverForImageClassificationFourier",
  865. "PerceiverForImageClassificationConvProcessing",
  866. ),
  867. ),
  868. ("poolformer", "PoolFormerForImageClassification"),
  869. ("pp_lcnet", "PPLCNetForImageClassification"),
  870. ("pvt", "PvtForImageClassification"),
  871. ("pvt_v2", "PvtV2ForImageClassification"),
  872. ("regnet", "RegNetForImageClassification"),
  873. ("resnet", "ResNetForImageClassification"),
  874. ("segformer", "SegformerForImageClassification"),
  875. ("shieldgemma2", "ShieldGemma2ForImageClassification"),
  876. ("siglip", "SiglipForImageClassification"),
  877. ("siglip2", "Siglip2ForImageClassification"),
  878. ("swiftformer", "SwiftFormerForImageClassification"),
  879. ("swin", "SwinForImageClassification"),
  880. ("swinv2", "Swinv2ForImageClassification"),
  881. ("textnet", "TextNetForImageClassification"),
  882. ("timm_wrapper", "TimmWrapperForImageClassification"),
  883. ("vit", "ViTForImageClassification"),
  884. ("vit_msn", "ViTMSNForImageClassification"),
  885. ]
  886. )
  887. MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
  888. [
  889. # Do not add new models here, this class will be deprecated in the future.
  890. # Model for Image Segmentation mapping
  891. ("detr", "DetrForSegmentation"),
  892. ]
  893. )
  894. MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
  895. [
  896. # Model for Semantic Segmentation mapping
  897. ("beit", "BeitForSemanticSegmentation"),
  898. ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
  899. ("dpt", "DPTForSemanticSegmentation"),
  900. ("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
  901. ("mobilevit", "MobileViTForSemanticSegmentation"),
  902. ("mobilevitv2", "MobileViTV2ForSemanticSegmentation"),
  903. ("segformer", "SegformerForSemanticSegmentation"),
  904. ("upernet", "UperNetForSemanticSegmentation"),
  905. ]
  906. )
  907. MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
  908. [
  909. # Model for Instance Segmentation mapping
  910. # MaskFormerForInstanceSegmentation can be removed from this mapping in v5
  911. ("maskformer", "MaskFormerForInstanceSegmentation"),
  912. ]
  913. )
  914. MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
  915. [
  916. # Model for Universal Segmentation mapping
  917. ("detr", "DetrForSegmentation"),
  918. ("eomt", "EomtForUniversalSegmentation"),
  919. ("eomt_dinov3", "EomtDinov3ForUniversalSegmentation"),
  920. ("mask2former", "Mask2FormerForUniversalSegmentation"),
  921. ("maskformer", "MaskFormerForInstanceSegmentation"),
  922. ("oneformer", "OneFormerForUniversalSegmentation"),
  923. ("videomt", "VideomtForUniversalSegmentation"),
  924. ]
  925. )
  926. MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  927. [
  928. ("timesformer", "TimesformerForVideoClassification"),
  929. ("videomae", "VideoMAEForVideoClassification"),
  930. ("vivit", "VivitForVideoClassification"),
  931. ("vjepa2", "VJEPA2ForVideoClassification"),
  932. ]
  933. )
  934. MODEL_FOR_RETRIEVAL_MAPPING_NAMES = OrderedDict(
  935. [
  936. ("colmodernvbert", "ColModernVBertForRetrieval"),
  937. ("colpali", "ColPaliForRetrieval"),
  938. ]
  939. )
  940. MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
  941. [
  942. ("aria", "AriaForConditionalGeneration"),
  943. ("aya_vision", "AyaVisionForConditionalGeneration"),
  944. ("blip", "BlipForConditionalGeneration"),
  945. ("blip-2", "Blip2ForConditionalGeneration"),
  946. ("chameleon", "ChameleonForConditionalGeneration"),
  947. ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),
  948. ("deepseek_vl", "DeepseekVLForConditionalGeneration"),
  949. ("deepseek_vl_hybrid", "DeepseekVLHybridForConditionalGeneration"),
  950. ("emu3", "Emu3ForConditionalGeneration"),
  951. ("ernie4_5_vl_moe", "Ernie4_5_VLMoeForConditionalGeneration"),
  952. ("evolla", "EvollaForProteinText2Text"),
  953. ("fast_vlm", "FastVlmForConditionalGeneration"),
  954. ("florence2", "Florence2ForConditionalGeneration"),
  955. ("fuyu", "FuyuForCausalLM"),
  956. ("gemma3", "Gemma3ForConditionalGeneration"),
  957. ("gemma3n", "Gemma3nForConditionalGeneration"),
  958. ("gemma4", "Gemma4ForConditionalGeneration"),
  959. ("git", "GitForCausalLM"),
  960. ("glm46v", "Glm46VForConditionalGeneration"),
  961. ("glm4v", "Glm4vForConditionalGeneration"),
  962. ("glm4v_moe", "Glm4vMoeForConditionalGeneration"),
  963. ("glm_ocr", "GlmOcrForConditionalGeneration"),
  964. ("got_ocr2", "GotOcr2ForConditionalGeneration"),
  965. ("idefics", "IdeficsForVisionText2Text"),
  966. ("idefics2", "Idefics2ForConditionalGeneration"),
  967. ("idefics3", "Idefics3ForConditionalGeneration"),
  968. ("instructblip", "InstructBlipForConditionalGeneration"),
  969. ("instructblipvideo", "InstructBlipVideoForConditionalGeneration"),
  970. ("internvl", "InternVLForConditionalGeneration"),
  971. ("janus", "JanusForConditionalGeneration"),
  972. ("kosmos-2", "Kosmos2ForConditionalGeneration"),
  973. ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"),
  974. ("lfm2_vl", "Lfm2VlForConditionalGeneration"),
  975. ("lighton_ocr", "LightOnOcrForConditionalGeneration"),
  976. ("llama4", "Llama4ForConditionalGeneration"),
  977. ("llava", "LlavaForConditionalGeneration"),
  978. ("llava_next", "LlavaNextForConditionalGeneration"),
  979. ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
  980. ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
  981. ("mistral3", "Mistral3ForConditionalGeneration"),
  982. ("mistral4", "Mistral4ForCausalLM"),
  983. ("mllama", "MllamaForConditionalGeneration"),
  984. ("ovis2", "Ovis2ForConditionalGeneration"),
  985. ("paddleocr_vl", "PaddleOCRVLForConditionalGeneration"),
  986. ("paligemma", "PaliGemmaForConditionalGeneration"),
  987. ("perception_lm", "PerceptionLMForConditionalGeneration"),
  988. ("pi0", "PI0ForConditionalGeneration"),
  989. ("pix2struct", "Pix2StructForConditionalGeneration"),
  990. ("pixtral", "LlavaForConditionalGeneration"),
  991. ("pp_chart2table", "GotOcr2ForConditionalGeneration"),
  992. ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
  993. ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
  994. ("qwen3_5", "Qwen3_5ForConditionalGeneration"),
  995. ("qwen3_5_moe", "Qwen3_5MoeForConditionalGeneration"),
  996. ("qwen3_vl", "Qwen3VLForConditionalGeneration"),
  997. ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),
  998. ("shieldgemma2", "Gemma3ForConditionalGeneration"),
  999. ("smolvlm", "SmolVLMForConditionalGeneration"),
  1000. ("t5gemma2", "T5Gemma2ForConditionalGeneration"),
  1001. ("udop", "UdopForConditionalGeneration"),
  1002. ("video_llama_3", "VideoLlama3ForConditionalGeneration"),
  1003. ("video_llava", "VideoLlavaForConditionalGeneration"),
  1004. ("vipllava", "VipLlavaForConditionalGeneration"),
  1005. ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
  1006. ]
  1007. )
  1008. # Models that accept text and optionally multimodal data in inputs
  1009. # and can generate text and optionally multimodal data.
  1010. MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES = OrderedDict(
  1011. [
  1012. *list(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.items()),
  1013. ("glmasr", "GlmAsrForConditionalGeneration"),
  1014. ("granite_speech", "GraniteSpeechForConditionalGeneration"),
  1015. ("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
  1016. ("phi4_multimodal", "Phi4MultimodalForCausalLM"),
  1017. ("qwen2_5_omni", "Qwen2_5OmniForConditionalGeneration"),
  1018. ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
  1019. ("qwen3_omni_moe", "Qwen3OmniMoeForConditionalGeneration"),
  1020. ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"),
  1021. ("voxtral", "VoxtralForConditionalGeneration"),
  1022. ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"),
  1023. ]
  1024. )
  1025. MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
  1026. [
  1027. # Model for Masked LM mapping
  1028. ("albert", "AlbertForMaskedLM"),
  1029. ("bart", "BartForConditionalGeneration"),
  1030. ("bert", "BertForMaskedLM"),
  1031. ("big_bird", "BigBirdForMaskedLM"),
  1032. ("camembert", "CamembertForMaskedLM"),
  1033. ("convbert", "ConvBertForMaskedLM"),
  1034. ("data2vec-text", "Data2VecTextForMaskedLM"),
  1035. ("deberta", "DebertaForMaskedLM"),
  1036. ("deberta-v2", "DebertaV2ForMaskedLM"),
  1037. ("distilbert", "DistilBertForMaskedLM"),
  1038. ("electra", "ElectraForMaskedLM"),
  1039. ("ernie", "ErnieForMaskedLM"),
  1040. ("esm", "EsmForMaskedLM"),
  1041. ("eurobert", "EuroBertForMaskedLM"),
  1042. ("flaubert", "FlaubertWithLMHeadModel"),
  1043. ("fnet", "FNetForMaskedLM"),
  1044. ("funnel", "FunnelForMaskedLM"),
  1045. ("ibert", "IBertForMaskedLM"),
  1046. ("jina_embeddings_v3", "JinaEmbeddingsV3ForMaskedLM"),
  1047. ("layoutlm", "LayoutLMForMaskedLM"),
  1048. ("longformer", "LongformerForMaskedLM"),
  1049. ("luke", "LukeForMaskedLM"),
  1050. ("mbart", "MBartForConditionalGeneration"),
  1051. ("megatron-bert", "MegatronBertForMaskedLM"),
  1052. ("mobilebert", "MobileBertForMaskedLM"),
  1053. ("modernbert", "ModernBertForMaskedLM"),
  1054. ("modernvbert", "ModernVBertForMaskedLM"),
  1055. ("mpnet", "MPNetForMaskedLM"),
  1056. ("mra", "MraForMaskedLM"),
  1057. ("mvp", "MvpForConditionalGeneration"),
  1058. ("nomic_bert", "NomicBertForMaskedLM"),
  1059. ("nystromformer", "NystromformerForMaskedLM"),
  1060. ("perceiver", "PerceiverForMaskedLM"),
  1061. ("reformer", "ReformerForMaskedLM"),
  1062. ("rembert", "RemBertForMaskedLM"),
  1063. ("roberta", "RobertaForMaskedLM"),
  1064. ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
  1065. ("roc_bert", "RoCBertForMaskedLM"),
  1066. ("roformer", "RoFormerForMaskedLM"),
  1067. ("squeezebert", "SqueezeBertForMaskedLM"),
  1068. ("tapas", "TapasForMaskedLM"),
  1069. ("xlm", "XLMWithLMHeadModel"),
  1070. ("xlm-roberta", "XLMRobertaForMaskedLM"),
  1071. ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
  1072. ("xmod", "XmodForMaskedLM"),
  1073. ("yoso", "YosoForMaskedLM"),
  1074. ]
  1075. )
  1076. MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
  1077. [
  1078. # Model for Object Detection mapping
  1079. ("conditional_detr", "ConditionalDetrForObjectDetection"),
  1080. ("d_fine", "DFineForObjectDetection"),
  1081. ("dab-detr", "DabDetrForObjectDetection"),
  1082. ("deformable_detr", "DeformableDetrForObjectDetection"),
  1083. ("detr", "DetrForObjectDetection"),
  1084. ("lw_detr", "LwDetrForObjectDetection"),
  1085. ("pp_doclayout_v2", "PPDocLayoutV2ForObjectDetection"),
  1086. ("pp_doclayout_v3", "PPDocLayoutV3ForObjectDetection"),
  1087. ("pp_ocrv5_mobile_det", "PPOCRV5MobileDetForObjectDetection"),
  1088. ("pp_ocrv5_server_det", "PPOCRV5ServerDetForObjectDetection"),
  1089. ("rt_detr", "RTDetrForObjectDetection"),
  1090. ("rt_detr_v2", "RTDetrV2ForObjectDetection"),
  1091. ("table-transformer", "TableTransformerForObjectDetection"),
  1092. ("yolos", "YolosForObjectDetection"),
  1093. ]
  1094. )
  1095. MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
  1096. [
  1097. # Model for Zero Shot Object Detection mapping
  1098. ("grounding-dino", "GroundingDinoForObjectDetection"),
  1099. ("mm-grounding-dino", "MMGroundingDinoForObjectDetection"),
  1100. ("omdet-turbo", "OmDetTurboForObjectDetection"),
  1101. ("owlv2", "Owlv2ForObjectDetection"),
  1102. ("owlvit", "OwlViTForObjectDetection"),
  1103. ]
  1104. )
  1105. MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
  1106. [
  1107. # Model for depth estimation mapping
  1108. ("chmv2", "CHMv2ForDepthEstimation"),
  1109. ("depth_anything", "DepthAnythingForDepthEstimation"),
  1110. ("depth_pro", "DepthProForDepthEstimation"),
  1111. ("dpt", "DPTForDepthEstimation"),
  1112. ("glpn", "GLPNForDepthEstimation"),
  1113. ("prompt_depth_anything", "PromptDepthAnythingForDepthEstimation"),
  1114. ("zoedepth", "ZoeDepthForDepthEstimation"),
  1115. ]
  1116. )
  1117. MODEL_FOR_TEXT_RECOGNITION_MAPPING_NAMES = OrderedDict(
  1118. [
  1119. ("pp_ocrv5_mobile_rec", "PPOCRV5MobileRecForTextRecognition"),
  1120. ("pp_ocrv5_server_rec", "PPOCRV5ServerRecForTextRecognition"),
  1121. ]
  1122. )
  1123. MODEL_FOR_TABLE_RECOGNITION_MAPPING_NAMES = OrderedDict(
  1124. [
  1125. ("slanext", "SLANeXtForTableRecognition"),
  1126. ]
  1127. )
  1128. MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
  1129. [
  1130. # Model for Seq2Seq Causal LM mapping
  1131. ("audioflamingo3", "AudioFlamingo3ForConditionalGeneration"),
  1132. ("bart", "BartForConditionalGeneration"),
  1133. ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
  1134. ("blenderbot", "BlenderbotForConditionalGeneration"),
  1135. ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
  1136. ("encoder-decoder", "EncoderDecoderModel"),
  1137. ("fsmt", "FSMTForConditionalGeneration"),
  1138. ("glmasr", "GlmAsrForConditionalGeneration"),
  1139. ("granite_speech", "GraniteSpeechForConditionalGeneration"),
  1140. ("led", "LEDForConditionalGeneration"),
  1141. ("longt5", "LongT5ForConditionalGeneration"),
  1142. ("m2m_100", "M2M100ForConditionalGeneration"),
  1143. ("marian", "MarianMTModel"),
  1144. ("mbart", "MBartForConditionalGeneration"),
  1145. ("mt5", "MT5ForConditionalGeneration"),
  1146. ("musicflamingo", "MusicFlamingoForConditionalGeneration"),
  1147. ("mvp", "MvpForConditionalGeneration"),
  1148. ("nllb-moe", "NllbMoeForConditionalGeneration"),
  1149. ("pegasus", "PegasusForConditionalGeneration"),
  1150. ("pegasus_x", "PegasusXForConditionalGeneration"),
  1151. ("plbart", "PLBartForConditionalGeneration"),
  1152. ("prophetnet", "ProphetNetForConditionalGeneration"),
  1153. ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
  1154. ("seamless_m4t", "SeamlessM4TForTextToText"),
  1155. ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"),
  1156. ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
  1157. ("t5", "T5ForConditionalGeneration"),
  1158. ("t5gemma", "T5GemmaForConditionalGeneration"),
  1159. ("t5gemma2", "T5Gemma2ForConditionalGeneration"),
  1160. ("umt5", "UMT5ForConditionalGeneration"),
  1161. ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"),
  1162. ("voxtral", "VoxtralForConditionalGeneration"),
  1163. ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"),
  1164. ]
  1165. )
  1166. MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
  1167. [
  1168. ("cohere_asr", "CohereAsrForConditionalGeneration"),
  1169. ("dia", "DiaForConditionalGeneration"),
  1170. ("granite_speech", "GraniteSpeechForConditionalGeneration"),
  1171. ("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
  1172. ("moonshine", "MoonshineForConditionalGeneration"),
  1173. ("moonshine_streaming", "MoonshineStreamingForConditionalGeneration"),
  1174. ("pop2piano", "Pop2PianoForConditionalGeneration"),
  1175. ("seamless_m4t", "SeamlessM4TForSpeechToText"),
  1176. ("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"),
  1177. ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
  1178. ("speech_to_text", "Speech2TextForConditionalGeneration"),
  1179. ("speecht5", "SpeechT5ForSpeechToText"),
  1180. ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"),
  1181. ("voxtral", "VoxtralForConditionalGeneration"),
  1182. ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"),
  1183. ("whisper", "WhisperForConditionalGeneration"),
  1184. ]
  1185. )
  1186. MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  1187. [
  1188. # Model for Sequence Classification mapping
  1189. ("albert", "AlbertForSequenceClassification"),
  1190. ("arcee", "ArceeForSequenceClassification"),
  1191. ("bart", "BartForSequenceClassification"),
  1192. ("bert", "BertForSequenceClassification"),
  1193. ("big_bird", "BigBirdForSequenceClassification"),
  1194. ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
  1195. ("biogpt", "BioGptForSequenceClassification"),
  1196. ("bloom", "BloomForSequenceClassification"),
  1197. ("camembert", "CamembertForSequenceClassification"),
  1198. ("canine", "CanineForSequenceClassification"),
  1199. ("code_llama", "LlamaForSequenceClassification"),
  1200. ("convbert", "ConvBertForSequenceClassification"),
  1201. ("ctrl", "CTRLForSequenceClassification"),
  1202. ("data2vec-text", "Data2VecTextForSequenceClassification"),
  1203. ("deberta", "DebertaForSequenceClassification"),
  1204. ("deberta-v2", "DebertaV2ForSequenceClassification"),
  1205. ("deepseek_v2", "DeepseekV2ForSequenceClassification"),
  1206. ("deepseek_v3", "DeepseekV3ForSequenceClassification"),
  1207. ("diffllama", "DiffLlamaForSequenceClassification"),
  1208. ("distilbert", "DistilBertForSequenceClassification"),
  1209. ("doge", "DogeForSequenceClassification"),
  1210. ("electra", "ElectraForSequenceClassification"),
  1211. ("ernie", "ErnieForSequenceClassification"),
  1212. ("esm", "EsmForSequenceClassification"),
  1213. ("eurobert", "EuroBertForSequenceClassification"),
  1214. ("exaone4", "Exaone4ForSequenceClassification"),
  1215. ("falcon", "FalconForSequenceClassification"),
  1216. ("flaubert", "FlaubertForSequenceClassification"),
  1217. ("fnet", "FNetForSequenceClassification"),
  1218. ("funnel", "FunnelForSequenceClassification"),
  1219. ("gemma", "GemmaForSequenceClassification"),
  1220. ("gemma2", "Gemma2ForSequenceClassification"),
  1221. ("gemma3", "Gemma3ForSequenceClassification"),
  1222. ("gemma3_text", "Gemma3TextForSequenceClassification"),
  1223. ("glm", "GlmForSequenceClassification"),
  1224. ("glm4", "Glm4ForSequenceClassification"),
  1225. ("gpt-sw3", "GPT2ForSequenceClassification"),
  1226. ("gpt2", "GPT2ForSequenceClassification"),
  1227. ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
  1228. ("gpt_neo", "GPTNeoForSequenceClassification"),
  1229. ("gpt_neox", "GPTNeoXForSequenceClassification"),
  1230. ("gpt_oss", "GptOssForSequenceClassification"),
  1231. ("gptj", "GPTJForSequenceClassification"),
  1232. ("helium", "HeliumForSequenceClassification"),
  1233. ("hunyuan_v1_dense", "HunYuanDenseV1ForSequenceClassification"),
  1234. ("hunyuan_v1_moe", "HunYuanMoEV1ForSequenceClassification"),
  1235. ("ibert", "IBertForSequenceClassification"),
  1236. ("jamba", "JambaForSequenceClassification"),
  1237. ("jetmoe", "JetMoeForSequenceClassification"),
  1238. ("jina_embeddings_v3", "JinaEmbeddingsV3ForSequenceClassification"),
  1239. ("layoutlm", "LayoutLMForSequenceClassification"),
  1240. ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
  1241. ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
  1242. ("lilt", "LiltForSequenceClassification"),
  1243. ("llama", "LlamaForSequenceClassification"),
  1244. ("longformer", "LongformerForSequenceClassification"),
  1245. ("luke", "LukeForSequenceClassification"),
  1246. ("markuplm", "MarkupLMForSequenceClassification"),
  1247. ("mbart", "MBartForSequenceClassification"),
  1248. ("megatron-bert", "MegatronBertForSequenceClassification"),
  1249. ("minimax", "MiniMaxForSequenceClassification"),
  1250. ("ministral", "MinistralForSequenceClassification"),
  1251. ("ministral3", "Ministral3ForSequenceClassification"),
  1252. ("mistral", "MistralForSequenceClassification"),
  1253. ("mistral4", "Mistral4ForSequenceClassification"),
  1254. ("mixtral", "MixtralForSequenceClassification"),
  1255. ("mobilebert", "MobileBertForSequenceClassification"),
  1256. ("modernbert", "ModernBertForSequenceClassification"),
  1257. ("modernbert-decoder", "ModernBertDecoderForSequenceClassification"),
  1258. ("modernvbert", "ModernVBertForSequenceClassification"),
  1259. ("mpnet", "MPNetForSequenceClassification"),
  1260. ("mpt", "MptForSequenceClassification"),
  1261. ("mra", "MraForSequenceClassification"),
  1262. ("mt5", "MT5ForSequenceClassification"),
  1263. ("mvp", "MvpForSequenceClassification"),
  1264. ("nemotron", "NemotronForSequenceClassification"),
  1265. ("nomic_bert", "NomicBertForSequenceClassification"),
  1266. ("nystromformer", "NystromformerForSequenceClassification"),
  1267. ("openai-gpt", "OpenAIGPTForSequenceClassification"),
  1268. ("opt", "OPTForSequenceClassification"),
  1269. ("perceiver", "PerceiverForSequenceClassification"),
  1270. ("persimmon", "PersimmonForSequenceClassification"),
  1271. ("phi", "PhiForSequenceClassification"),
  1272. ("phi3", "Phi3ForSequenceClassification"),
  1273. ("phimoe", "PhimoeForSequenceClassification"),
  1274. ("plbart", "PLBartForSequenceClassification"),
  1275. ("qwen2", "Qwen2ForSequenceClassification"),
  1276. ("qwen2_moe", "Qwen2MoeForSequenceClassification"),
  1277. ("qwen3", "Qwen3ForSequenceClassification"),
  1278. ("qwen3_5", "Qwen3_5ForSequenceClassification"),
  1279. ("qwen3_5_text", "Qwen3_5ForSequenceClassification"),
  1280. ("qwen3_moe", "Qwen3MoeForSequenceClassification"),
  1281. ("qwen3_next", "Qwen3NextForSequenceClassification"),
  1282. ("reformer", "ReformerForSequenceClassification"),
  1283. ("rembert", "RemBertForSequenceClassification"),
  1284. ("roberta", "RobertaForSequenceClassification"),
  1285. ("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"),
  1286. ("roc_bert", "RoCBertForSequenceClassification"),
  1287. ("roformer", "RoFormerForSequenceClassification"),
  1288. ("seed_oss", "SeedOssForSequenceClassification"),
  1289. ("smollm3", "SmolLM3ForSequenceClassification"),
  1290. ("squeezebert", "SqueezeBertForSequenceClassification"),
  1291. ("stablelm", "StableLmForSequenceClassification"),
  1292. ("starcoder2", "Starcoder2ForSequenceClassification"),
  1293. ("t5", "T5ForSequenceClassification"),
  1294. ("t5gemma", "T5GemmaForSequenceClassification"),
  1295. ("t5gemma2", "T5Gemma2ForSequenceClassification"),
  1296. ("tapas", "TapasForSequenceClassification"),
  1297. ("umt5", "UMT5ForSequenceClassification"),
  1298. ("xlm", "XLMForSequenceClassification"),
  1299. ("xlm-roberta", "XLMRobertaForSequenceClassification"),
  1300. ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
  1301. ("xlnet", "XLNetForSequenceClassification"),
  1302. ("xmod", "XmodForSequenceClassification"),
  1303. ("yoso", "YosoForSequenceClassification"),
  1304. ("zamba", "ZambaForSequenceClassification"),
  1305. ("zamba2", "Zamba2ForSequenceClassification"),
  1306. ]
  1307. )
# Name table for question answering heads.
# Each entry is a (key, class name) pair of plain strings: the key is a model
# identifier (e.g. "bert") and the value is the name of the class to use for it.
# Entries are listed in alphabetical order by key.
# Note: "flaubert", "xlm" and "xlnet" point at the `...ForQuestionAnsweringSimple`
# variants rather than a `...ForQuestionAnswering` class.
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "AlbertForQuestionAnswering"),
        ("arcee", "ArceeForQuestionAnswering"),
        ("bart", "BartForQuestionAnswering"),
        ("bert", "BertForQuestionAnswering"),
        ("big_bird", "BigBirdForQuestionAnswering"),
        ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
        ("bloom", "BloomForQuestionAnswering"),
        ("camembert", "CamembertForQuestionAnswering"),
        ("canine", "CanineForQuestionAnswering"),
        ("convbert", "ConvBertForQuestionAnswering"),
        ("data2vec-text", "Data2VecTextForQuestionAnswering"),
        ("deberta", "DebertaForQuestionAnswering"),
        ("deberta-v2", "DebertaV2ForQuestionAnswering"),
        ("diffllama", "DiffLlamaForQuestionAnswering"),
        ("distilbert", "DistilBertForQuestionAnswering"),
        ("electra", "ElectraForQuestionAnswering"),
        ("ernie", "ErnieForQuestionAnswering"),
        ("exaone4", "Exaone4ForQuestionAnswering"),
        ("falcon", "FalconForQuestionAnswering"),
        ("flaubert", "FlaubertForQuestionAnsweringSimple"),
        ("fnet", "FNetForQuestionAnswering"),
        ("funnel", "FunnelForQuestionAnswering"),
        ("gpt2", "GPT2ForQuestionAnswering"),
        ("gpt_neo", "GPTNeoForQuestionAnswering"),
        ("gpt_neox", "GPTNeoXForQuestionAnswering"),
        ("gptj", "GPTJForQuestionAnswering"),
        ("ibert", "IBertForQuestionAnswering"),
        ("jina_embeddings_v3", "JinaEmbeddingsV3ForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
        ("led", "LEDForQuestionAnswering"),
        ("lilt", "LiltForQuestionAnswering"),
        ("llama", "LlamaForQuestionAnswering"),
        ("longformer", "LongformerForQuestionAnswering"),
        ("luke", "LukeForQuestionAnswering"),
        ("lxmert", "LxmertForQuestionAnswering"),
        ("markuplm", "MarkupLMForQuestionAnswering"),
        ("mbart", "MBartForQuestionAnswering"),
        ("megatron-bert", "MegatronBertForQuestionAnswering"),
        ("minimax", "MiniMaxForQuestionAnswering"),
        ("ministral", "MinistralForQuestionAnswering"),
        ("ministral3", "Ministral3ForQuestionAnswering"),
        ("mistral", "MistralForQuestionAnswering"),
        ("mixtral", "MixtralForQuestionAnswering"),
        ("mobilebert", "MobileBertForQuestionAnswering"),
        ("modernbert", "ModernBertForQuestionAnswering"),
        ("mpnet", "MPNetForQuestionAnswering"),
        ("mpt", "MptForQuestionAnswering"),
        ("mra", "MraForQuestionAnswering"),
        ("mt5", "MT5ForQuestionAnswering"),
        ("mvp", "MvpForQuestionAnswering"),
        ("nemotron", "NemotronForQuestionAnswering"),
        ("nystromformer", "NystromformerForQuestionAnswering"),
        ("opt", "OPTForQuestionAnswering"),
        ("qwen2", "Qwen2ForQuestionAnswering"),
        ("qwen2_moe", "Qwen2MoeForQuestionAnswering"),
        ("qwen3", "Qwen3ForQuestionAnswering"),
        ("qwen3_moe", "Qwen3MoeForQuestionAnswering"),
        ("qwen3_next", "Qwen3NextForQuestionAnswering"),
        ("reformer", "ReformerForQuestionAnswering"),
        ("rembert", "RemBertForQuestionAnswering"),
        ("roberta", "RobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"),
        ("roc_bert", "RoCBertForQuestionAnswering"),
        ("roformer", "RoFormerForQuestionAnswering"),
        ("seed_oss", "SeedOssForQuestionAnswering"),
        ("smollm3", "SmolLM3ForQuestionAnswering"),
        ("splinter", "SplinterForQuestionAnswering"),
        ("squeezebert", "SqueezeBertForQuestionAnswering"),
        ("t5", "T5ForQuestionAnswering"),
        ("umt5", "UMT5ForQuestionAnswering"),
        ("xlm", "XLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
        ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
        ("xlnet", "XLNetForQuestionAnsweringSimple"),
        ("xmod", "XmodForQuestionAnswering"),
        ("yoso", "YosoForQuestionAnswering"),
    ]
)
# Name table for table question answering heads: (model key, class name) string pairs.
# Currently holds a single entry, for "tapas".
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TapasForQuestionAnswering"),
    ]
)
# Name table for visual question answering: (model key, class name) string pairs.
# Note: "blip-2" maps to its conditional-generation class, not to a
# `...ForQuestionAnswering` class like the other two entries.
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Visual Question Answering mapping
        ("blip", "BlipForQuestionAnswering"),
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("vilt", "ViltForQuestionAnswering"),
    ]
)
# Name table for document question answering: (model key, class name) string pairs.
# All three entries are LayoutLM-family `...ForQuestionAnswering` classes.
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Document Question Answering mapping
        ("layoutlm", "LayoutLMForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
    ]
)
# Name table for token classification heads.
# Each entry is a (key, class name) pair of plain strings: the key is a model
# identifier and the value is the name of the class to use for it.
# Entries are listed in alphabetical order by key.
# Note: "gpt-sw3" and "gpt2" both resolve to `GPT2ForTokenClassification`.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "AlbertForTokenClassification"),
        ("apertus", "ApertusForTokenClassification"),
        ("arcee", "ArceeForTokenClassification"),
        ("bert", "BertForTokenClassification"),
        ("big_bird", "BigBirdForTokenClassification"),
        ("biogpt", "BioGptForTokenClassification"),
        ("bloom", "BloomForTokenClassification"),
        ("bros", "BrosForTokenClassification"),
        ("camembert", "CamembertForTokenClassification"),
        ("canine", "CanineForTokenClassification"),
        ("convbert", "ConvBertForTokenClassification"),
        ("data2vec-text", "Data2VecTextForTokenClassification"),
        ("deberta", "DebertaForTokenClassification"),
        ("deberta-v2", "DebertaV2ForTokenClassification"),
        ("deepseek_v3", "DeepseekV3ForTokenClassification"),
        ("diffllama", "DiffLlamaForTokenClassification"),
        ("distilbert", "DistilBertForTokenClassification"),
        ("electra", "ElectraForTokenClassification"),
        ("ernie", "ErnieForTokenClassification"),
        ("esm", "EsmForTokenClassification"),
        ("eurobert", "EuroBertForTokenClassification"),
        ("exaone4", "Exaone4ForTokenClassification"),
        ("falcon", "FalconForTokenClassification"),
        ("flaubert", "FlaubertForTokenClassification"),
        ("fnet", "FNetForTokenClassification"),
        ("funnel", "FunnelForTokenClassification"),
        ("gemma", "GemmaForTokenClassification"),
        ("gemma2", "Gemma2ForTokenClassification"),
        ("glm", "GlmForTokenClassification"),
        ("glm4", "Glm4ForTokenClassification"),
        ("gpt-sw3", "GPT2ForTokenClassification"),
        ("gpt2", "GPT2ForTokenClassification"),
        ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
        ("gpt_neo", "GPTNeoForTokenClassification"),
        ("gpt_neox", "GPTNeoXForTokenClassification"),
        ("gpt_oss", "GptOssForTokenClassification"),
        ("helium", "HeliumForTokenClassification"),
        ("ibert", "IBertForTokenClassification"),
        ("jina_embeddings_v3", "JinaEmbeddingsV3ForTokenClassification"),
        ("layoutlm", "LayoutLMForTokenClassification"),
        ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
        ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
        ("lilt", "LiltForTokenClassification"),
        ("llama", "LlamaForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
        ("luke", "LukeForTokenClassification"),
        ("markuplm", "MarkupLMForTokenClassification"),
        ("megatron-bert", "MegatronBertForTokenClassification"),
        ("minimax", "MiniMaxForTokenClassification"),
        ("ministral", "MinistralForTokenClassification"),
        ("ministral3", "Ministral3ForTokenClassification"),
        ("mistral", "MistralForTokenClassification"),
        ("mistral4", "Mistral4ForTokenClassification"),
        ("mixtral", "MixtralForTokenClassification"),
        ("mobilebert", "MobileBertForTokenClassification"),
        ("modernbert", "ModernBertForTokenClassification"),
        ("modernvbert", "ModernVBertForTokenClassification"),
        ("mpnet", "MPNetForTokenClassification"),
        ("mpt", "MptForTokenClassification"),
        ("mra", "MraForTokenClassification"),
        ("mt5", "MT5ForTokenClassification"),
        ("nemotron", "NemotronForTokenClassification"),
        ("nomic_bert", "NomicBertForTokenClassification"),
        ("nystromformer", "NystromformerForTokenClassification"),
        ("persimmon", "PersimmonForTokenClassification"),
        ("phi", "PhiForTokenClassification"),
        ("phi3", "Phi3ForTokenClassification"),
        ("qwen2", "Qwen2ForTokenClassification"),
        ("qwen2_moe", "Qwen2MoeForTokenClassification"),
        ("qwen3", "Qwen3ForTokenClassification"),
        ("qwen3_moe", "Qwen3MoeForTokenClassification"),
        ("qwen3_next", "Qwen3NextForTokenClassification"),
        ("rembert", "RemBertForTokenClassification"),
        ("roberta", "RobertaForTokenClassification"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
        ("roc_bert", "RoCBertForTokenClassification"),
        ("roformer", "RoFormerForTokenClassification"),
        ("seed_oss", "SeedOssForTokenClassification"),
        ("smollm3", "SmolLM3ForTokenClassification"),
        ("squeezebert", "SqueezeBertForTokenClassification"),
        ("stablelm", "StableLmForTokenClassification"),
        ("starcoder2", "Starcoder2ForTokenClassification"),
        ("t5", "T5ForTokenClassification"),
        ("t5gemma", "T5GemmaForTokenClassification"),
        ("t5gemma2", "T5Gemma2ForTokenClassification"),
        ("umt5", "UMT5ForTokenClassification"),
        ("xlm", "XLMForTokenClassification"),
        ("xlm-roberta", "XLMRobertaForTokenClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
        ("xlnet", "XLNetForTokenClassification"),
        ("xmod", "XmodForTokenClassification"),
        ("yoso", "YosoForTokenClassification"),
    ]
)
# Name table for multiple choice heads: (model key, class name) string pairs,
# listed in alphabetical order by key.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "AlbertForMultipleChoice"),
        ("bert", "BertForMultipleChoice"),
        ("big_bird", "BigBirdForMultipleChoice"),
        ("camembert", "CamembertForMultipleChoice"),
        ("canine", "CanineForMultipleChoice"),
        ("convbert", "ConvBertForMultipleChoice"),
        ("data2vec-text", "Data2VecTextForMultipleChoice"),
        ("deberta-v2", "DebertaV2ForMultipleChoice"),
        ("distilbert", "DistilBertForMultipleChoice"),
        ("electra", "ElectraForMultipleChoice"),
        ("ernie", "ErnieForMultipleChoice"),
        ("flaubert", "FlaubertForMultipleChoice"),
        ("fnet", "FNetForMultipleChoice"),
        ("funnel", "FunnelForMultipleChoice"),
        ("ibert", "IBertForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
        ("luke", "LukeForMultipleChoice"),
        ("megatron-bert", "MegatronBertForMultipleChoice"),
        ("mobilebert", "MobileBertForMultipleChoice"),
        ("modernbert", "ModernBertForMultipleChoice"),
        ("mpnet", "MPNetForMultipleChoice"),
        ("mra", "MraForMultipleChoice"),
        ("nystromformer", "NystromformerForMultipleChoice"),
        ("rembert", "RemBertForMultipleChoice"),
        ("roberta", "RobertaForMultipleChoice"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"),
        ("roc_bert", "RoCBertForMultipleChoice"),
        ("roformer", "RoFormerForMultipleChoice"),
        ("squeezebert", "SqueezeBertForMultipleChoice"),
        ("xlm", "XLMForMultipleChoice"),
        ("xlm-roberta", "XLMRobertaForMultipleChoice"),
        ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
        ("xlnet", "XLNetForMultipleChoice"),
        ("xmod", "XmodForMultipleChoice"),
        ("yoso", "YosoForMultipleChoice"),
    ]
)
# Name table for next sentence prediction heads: (model key, class name) string pairs.
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        ("bert", "BertForNextSentencePrediction"),
        ("ernie", "ErnieForNextSentencePrediction"),
        ("fnet", "FNetForNextSentencePrediction"),
        ("megatron-bert", "MegatronBertForNextSentencePrediction"),
        ("mobilebert", "MobileBertForNextSentencePrediction"),
    ]
)
# Name table for audio classification: (model key, class name) string pairs.
# Note: most entries use a `...ForSequenceClassification` class; only
# "audio-spectrogram-transformer" and "whisper" use `...ForAudioClassification`.
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Classification mapping
        ("audio-spectrogram-transformer", "ASTForAudioClassification"),
        ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
        ("hubert", "HubertForSequenceClassification"),
        ("sew", "SEWForSequenceClassification"),
        ("sew-d", "SEWDForSequenceClassification"),
        ("unispeech", "UniSpeechForSequenceClassification"),
        ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
        ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
        ("wav2vec2-bert", "Wav2Vec2BertForSequenceClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
        ("wavlm", "WavLMForSequenceClassification"),
        ("whisper", "WhisperForAudioClassification"),
    ]
)
# Name table for CTC heads: (model key, class name) string pairs.
# Note: the "lasr_ctc" and "parakeet_ctc" keys carry a `_ctc` suffix, unlike the
# other keys in this table.
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
    [
        # Model for Connectionist temporal classification (CTC) mapping
        ("data2vec-audio", "Data2VecAudioForCTC"),
        ("hubert", "HubertForCTC"),
        ("lasr_ctc", "LasrForCTC"),
        ("parakeet_ctc", "ParakeetForCTC"),
        ("sew", "SEWForCTC"),
        ("sew-d", "SEWDForCTC"),
        ("unispeech", "UniSpeechForCTC"),
        ("unispeech-sat", "UniSpeechSatForCTC"),
        ("wav2vec2", "Wav2Vec2ForCTC"),
        ("wav2vec2-bert", "Wav2Vec2BertForCTC"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
        ("wavlm", "WavLMForCTC"),
    ]
)
# Name table for audio frame classification heads: (model key, class name) string pairs.
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Frame Classification mapping
        ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
        ("wav2vec2-bert", "Wav2Vec2BertForAudioFrameClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
        ("wavlm", "WavLMForAudioFrameClassification"),
    ]
)
# Name table for XVector heads: (model key, class name) string pairs.
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio XVector mapping
        ("data2vec-audio", "Data2VecAudioForXVector"),
        ("unispeech-sat", "UniSpeechSatForXVector"),
        ("wav2vec2", "Wav2Vec2ForXVector"),
        ("wav2vec2-bert", "Wav2Vec2BertForXVector"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
        ("wavlm", "WavLMForXVector"),
    ]
)
# Name table for text-to-spectrogram models: (model key, class name) string pairs.
# Note: "fastspeech2_conformer" maps to `FastSpeech2ConformerModel` here; the
# text-to-waveform table maps the same key to `FastSpeech2ConformerWithHifiGan`.
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text-To-Spectrogram mapping
        ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
        ("speecht5", "SpeechT5ForTextToSpeech"),
    ]
)
# Name table for text-to-waveform models: (model key, class name) string pairs.
# Note: "fastspeech2_conformer" and "fastspeech2_conformer_with_hifigan" both
# resolve to `FastSpeech2ConformerWithHifiGan`.
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text-To-Waveform mapping
        ("bark", "BarkModel"),
        ("csm", "CsmForConditionalGeneration"),
        ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
        ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
        ("higgs_audio_v2", "HiggsAudioV2ForConditionalGeneration"),
        ("musicgen", "MusicgenForConditionalGeneration"),
        ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
        ("qwen2_5_omni", "Qwen2_5OmniForConditionalGeneration"),
        ("qwen3_omni_moe", "Qwen3OmniMoeForConditionalGeneration"),
        ("seamless_m4t", "SeamlessM4TForTextToSpeech"),
        ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"),
        ("vits", "VitsModel"),
    ]
)
# Name table for zero-shot image classification: (model key, class name) string pairs.
# Note: every entry points at a base `...Model` class except "blip-2", which
# uses `Blip2ForImageTextRetrieval`.
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("align", "AlignModel"),
        ("altclip", "AltCLIPModel"),
        ("blip", "BlipModel"),
        ("blip-2", "Blip2ForImageTextRetrieval"),
        ("chinese_clip", "ChineseCLIPModel"),
        ("clip", "CLIPModel"),
        ("clipseg", "CLIPSegModel"),
        ("metaclip_2", "MetaClip2Model"),
        ("siglip", "SiglipModel"),
        ("siglip2", "Siglip2Model"),
    ]
)
# Name table for backbone classes: (model key, class name) string pairs,
# listed in alphabetical order by key. Every value is a `...Backbone` class name.
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
    [
        # Backbone mapping
        ("beit", "BeitBackbone"),
        ("bit", "BitBackbone"),
        ("convnext", "ConvNextBackbone"),
        ("convnextv2", "ConvNextV2Backbone"),
        ("dinat", "DinatBackbone"),
        ("dinov2", "Dinov2Backbone"),
        ("dinov2_with_registers", "Dinov2WithRegistersBackbone"),
        ("dinov3_convnext", "DINOv3ConvNextBackbone"),
        ("dinov3_vit", "DINOv3ViTBackbone"),
        ("focalnet", "FocalNetBackbone"),
        ("hgnet_v2", "HGNetV2Backbone"),
        ("hiera", "HieraBackbone"),
        ("lw_detr_vit", "LwDetrViTBackbone"),
        ("maskformer-swin", "MaskFormerSwinBackbone"),
        ("pixio", "PixioBackbone"),
        ("pp_lcnet", "PPLCNetBackbone"),
        ("pp_lcnet_v3", "PPLCNetV3Backbone"),
        ("pvt_v2", "PvtV2Backbone"),
        ("resnet", "ResNetBackbone"),
        ("rt_detr_resnet", "RTDetrResNetBackbone"),
        ("swin", "SwinBackbone"),
        ("swinv2", "Swinv2Backbone"),
        ("textnet", "TextNetBackbone"),
        ("timm_backbone", "TimmBackbone"),
        ("uvdoc_backbone", "UVDocBackbone"),
        ("vitdet", "VitDetBackbone"),
        ("vitpose_backbone", "VitPoseBackbone"),
    ]
)
# Name table for mask generation models: (model key, class name) string pairs.
# Note: several keys share a class — "edgetam"/"edgetam_video" -> `EdgeTamModel`,
# "sam2"/"sam2_video" -> `Sam2Model`, "sam3_tracker"/"sam3_video" -> `Sam3TrackerModel`.
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Mask Generation mapping
        ("edgetam", "EdgeTamModel"),
        ("edgetam_video", "EdgeTamModel"),
        ("sam", "SamModel"),
        ("sam2", "Sam2Model"),
        ("sam2_video", "Sam2Model"),
        ("sam3_tracker", "Sam3TrackerModel"),
        ("sam3_video", "Sam3TrackerModel"),
        ("sam_hq", "SamHQModel"),
    ]
)
# Name table for keypoint detection models: (model key, class name) string pairs.
# Currently holds a single entry, for "superpoint".
MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Keypoint Detection mapping
        ("superpoint", "SuperPointForKeypointDetection"),
    ]
)
# Name table for keypoint matching models: (model key, class name) string pairs.
MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Keypoint Matching mapping
        ("efficientloftr", "EfficientLoFTRForKeypointMatching"),
        ("lightglue", "LightGlueForKeypointMatching"),
        ("superglue", "SuperGlueForKeypointMatching"),
    ]
)
# Name table for text encoding models: (model key, class name) string pairs,
# listed in alphabetical order by key.
# Note: most entries use the base `...Model` class. "mt5", "t5", "t5gemma" and
# "umt5" use `...EncoderModel` classes, and "emu3", "llama4" and "mllama" use
# `...TextModel` classes.
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text Encoding mapping
        ("albert", "AlbertModel"),
        ("bert", "BertModel"),
        ("big_bird", "BigBirdModel"),
        ("clip_text_model", "CLIPTextModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("distilbert", "DistilBertModel"),
        ("electra", "ElectraModel"),
        ("emu3", "Emu3TextModel"),
        ("flaubert", "FlaubertModel"),
        ("ibert", "IBertModel"),
        ("llama4", "Llama4TextModel"),
        ("longformer", "LongformerModel"),
        ("mllama", "MllamaTextModel"),
        ("mobilebert", "MobileBertModel"),
        ("mt5", "MT5EncoderModel"),
        ("nystromformer", "NystromformerModel"),
        ("reformer", "ReformerModel"),
        ("rembert", "RemBertModel"),
        ("roberta", "RobertaModel"),
        ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
        ("roc_bert", "RoCBertModel"),
        ("roformer", "RoFormerModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("t5", "T5EncoderModel"),
        ("t5gemma", "T5GemmaEncoderModel"),
        ("umt5", "UMT5EncoderModel"),
        ("xlm", "XLMModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
    ]
)
# Name table for time series classification: (model key, class name) string pairs.
# Note: the two class names follow different patterns
# (`...ForTimeSeriesClassification` vs. `...ForClassification`).
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Time Series Classification mapping
        ("patchtsmixer", "PatchTSMixerForTimeSeriesClassification"),
        ("patchtst", "PatchTSTForClassification"),
    ]
)
# Name table for time series regression: (model key, class name) string pairs.
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Time Series Regression mapping
        ("patchtsmixer", "PatchTSMixerForRegression"),
        ("patchtst", "PatchTSTForRegression"),
    ]
)
# Name table for time series prediction: (model key, class name) string pairs.
MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Time Series Prediction mapping
        ("timesfm", "TimesFmModelForPrediction"),
        ("timesfm2_5", "TimesFm2_5ModelForPrediction"),
    ]
)
# Name table for image-to-image models: (model key, class name) string pairs.
# Currently holds a single entry, for "swin2sr".
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image To Image mapping
        ("swin2sr", "Swin2SRForImageSuperResolution"),
    ]
)
# Name table for audio tokenization models: (model key, class name) string pairs.
# Note: unlike the sibling tables, this name has no `_MAPPING` infix
# (`..._TOKENIZATION_NAMES`, not `..._TOKENIZATION_MAPPING_NAMES`).
MODEL_FOR_AUDIO_TOKENIZATION_NAMES = OrderedDict(
    [
        # Model for Audio Tokenization mapping
        ("dac", "DacModel"),
        ("higgs_audio_v2_tokenizer", "HiggsAudioV2TokenizerModel"),
        ("vibevoice_acoustic_tokenizer", "VibeVoiceAcousticTokenizerModel"),
    ]
)
# Build one `_LazyAutoMapping` per `*_MAPPING_NAMES` table defined earlier in this
# module. Every mapping pairs the shared `CONFIG_MAPPING_NAMES` table with one
# task-specific name table; the resulting `*_MAPPING` objects are what the
# `AutoModelFor...` classes below reference through `_model_mapping`.
# NOTE(review): `_LazyAutoMapping` is defined outside this file section;
# presumably it defers resolving class names until lookup — confirm there.

# Base model, pretraining and language-modeling mappings.
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
# Vision, multimodal and text task mappings.
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
)
MODEL_FOR_MULTIMODAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES)
MODEL_FOR_RETRIEVAL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_RETRIEVAL_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_MAPPING_NAMES)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
MODEL_FOR_TEXT_RECOGNITION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_RECOGNITION_MAPPING_NAMES)
MODEL_FOR_TABLE_RECOGNITION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_RECOGNITION_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
# Audio and speech mappings.
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES
)
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES)
# Backbone, mask generation, keypoint, text encoding, time series and image-to-image mappings.
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_KEYPOINT_MATCHING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES)
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES
)
MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
# The source table here is `MODEL_FOR_AUDIO_TOKENIZATION_NAMES` (no `_MAPPING` infix in its name).
MODEL_FOR_AUDIO_TOKENIZATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_TOKENIZATION_NAMES)
# Auto class for mask generation models; resolves classes through
# MODEL_FOR_MASK_GENERATION_MAPPING.
class AutoModelForMaskGeneration(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
# Auto class for keypoint detection models; resolves classes through
# MODEL_FOR_KEYPOINT_DETECTION_MAPPING.
class AutoModelForKeypointDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING
# Auto class for keypoint matching models; resolves classes through
# MODEL_FOR_KEYPOINT_MATCHING_MAPPING.
class AutoModelForKeypointMatching(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_KEYPOINT_MATCHING_MAPPING
# Auto class for text encoding models; resolves classes through
# MODEL_FOR_TEXT_ENCODING_MAPPING.
class AutoModelForTextEncoding(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
# Auto class for image-to-image models; resolves classes through
# MODEL_FOR_IMAGE_TO_IMAGE_MAPPING.
class AutoModelForImageToImage(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING
# Generic auto class backed by MODEL_MAPPING (the base-model table).
class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING


# Rebind the name to whatever `auto_class_update` returns; no `head_doc` is
# passed for the base class.
# NOTE(review): `auto_class_update` is defined elsewhere — presumably it fills in
# the generated docstrings; confirm against its definition.
AutoModel = auto_class_update(AutoModel)
# Auto class for pretraining heads; resolves classes through
# MODEL_FOR_PRETRAINING_MAPPING.
class AutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


# Rebind the name to the result of `auto_class_update`, labelling the head as "pretraining".
AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
# Auto class for causal language modeling heads; resolves classes through
# MODEL_FOR_CAUSAL_LM_MAPPING.
class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING

    # override to give better return typehint
    # The body only forwards every argument to the parent `from_pretrained`;
    # the difference from the inherited method is the annotated signature
    # (return type given as the forward reference "_BaseModelWithGenerate").
    @classmethod
    def from_pretrained(
        cls: type["AutoModelForCausalLM"],
        pretrained_model_name_or_path: str | os.PathLike[str],
        *model_args,
        **kwargs,
    ) -> "_BaseModelWithGenerate":
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


# Rebind the name to the result of `auto_class_update`, labelling the head as
# "causal language modeling".
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
  # Auto class bound to MODEL_FOR_MASKED_LM_MAPPING.
  class AutoModelForMaskedLM(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "masked language modeling".
  AutoModelForMaskedLM = auto_class_update(
      AutoModelForMaskedLM,
      head_doc="masked language modeling",
  )
  # Auto class bound to MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
  class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


  # Rebind the name to the result of `auto_class_update`; this call also names an
  # example checkpoint explicitly.
  AutoModelForSeq2SeqLM = auto_class_update(
      AutoModelForSeq2SeqLM,
      checkpoint_for_example="google-t5/t5-base",
      head_doc="sequence-to-sequence language modeling",
  )
  # Auto class bound to MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
  class AutoModelForSequenceClassification(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "sequence classification".
  AutoModelForSequenceClassification = auto_class_update(
      AutoModelForSequenceClassification,
      head_doc="sequence classification",
  )
  # Auto class bound to MODEL_FOR_QUESTION_ANSWERING_MAPPING.
  class AutoModelForQuestionAnswering(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "question answering".
  AutoModelForQuestionAnswering = auto_class_update(
      AutoModelForQuestionAnswering,
      head_doc="question answering",
  )
  # Auto class bound to MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.
  class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


  # Rebind the name to the result of `auto_class_update`; this call also names an
  # example checkpoint explicitly.
  AutoModelForTableQuestionAnswering = auto_class_update(
      AutoModelForTableQuestionAnswering,
      checkpoint_for_example="google/tapas-base-finetuned-wtq",
      head_doc="table question answering",
  )
  # Auto class bound to MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING.
  class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


  # Rebind the name to the result of `auto_class_update`; this call also names an
  # example checkpoint explicitly.
  AutoModelForVisualQuestionAnswering = auto_class_update(
      AutoModelForVisualQuestionAnswering,
      checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
      head_doc="visual question answering",
  )
  # Auto class bound to MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING.
  class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


  # NOTE(review): the checkpoint string below deliberately contains `", revision="`
  # in the middle of a single-quoted literal. Presumably the value is spliced into a
  # double-quoted example call so that it also supplies a `revision` argument;
  # confirm against `auto_class_update` before "fixing" the quoting.
  AutoModelForDocumentQuestionAnswering = auto_class_update(
      AutoModelForDocumentQuestionAnswering,
      checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
      head_doc="document question answering",
  )
  # Auto class bound to MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
  class AutoModelForTokenClassification(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "token classification".
  AutoModelForTokenClassification = auto_class_update(
      AutoModelForTokenClassification,
      head_doc="token classification",
  )
  # Auto class bound to MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
  class AutoModelForMultipleChoice(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "multiple choice".
  AutoModelForMultipleChoice = auto_class_update(
      AutoModelForMultipleChoice,
      head_doc="multiple choice",
  )
  # Auto class bound to MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
  class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "next sentence prediction".
  AutoModelForNextSentencePrediction = auto_class_update(
      AutoModelForNextSentencePrediction,
      head_doc="next sentence prediction",
  )
  # Auto class bound to MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.
  class AutoModelForImageClassification(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "image classification".
  AutoModelForImageClassification = auto_class_update(
      AutoModelForImageClassification,
      head_doc="image classification",
  )
  # Auto class bound to MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING.
  class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "zero-shot image classification".
  AutoModelForZeroShotImageClassification = auto_class_update(
      AutoModelForZeroShotImageClassification,
      head_doc="zero-shot image classification",
  )
  # Auto class bound to MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.
  class AutoModelForImageSegmentation(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "image segmentation".
  AutoModelForImageSegmentation = auto_class_update(
      AutoModelForImageSegmentation,
      head_doc="image segmentation",
  )
  # Auto class bound to MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.
  class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "semantic segmentation".
  AutoModelForSemanticSegmentation = auto_class_update(
      AutoModelForSemanticSegmentation,
      head_doc="semantic segmentation",
  )
  # Auto class bound to MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING.
  class AutoModelForTimeSeriesPrediction(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "time-series prediction".
  AutoModelForTimeSeriesPrediction = auto_class_update(
      AutoModelForTimeSeriesPrediction,
      head_doc="time-series prediction",
  )
  # Auto class bound to MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING.
  class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "universal image segmentation".
  AutoModelForUniversalSegmentation = auto_class_update(
      AutoModelForUniversalSegmentation,
      head_doc="universal image segmentation",
  )
  # Auto class bound to MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.
  class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "instance segmentation".
  AutoModelForInstanceSegmentation = auto_class_update(
      AutoModelForInstanceSegmentation,
      head_doc="instance segmentation",
  )
  # Auto class bound to MODEL_FOR_OBJECT_DETECTION_MAPPING.
  class AutoModelForObjectDetection(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "object detection".
  AutoModelForObjectDetection = auto_class_update(
      AutoModelForObjectDetection,
      head_doc="object detection",
  )
  # Auto class bound to MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING.
  class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "zero-shot object detection".
  AutoModelForZeroShotObjectDetection = auto_class_update(
      AutoModelForZeroShotObjectDetection,
      head_doc="zero-shot object detection",
  )
  # Auto class bound to MODEL_FOR_DEPTH_ESTIMATION_MAPPING.
  class AutoModelForDepthEstimation(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "depth estimation".
  AutoModelForDepthEstimation = auto_class_update(
      AutoModelForDepthEstimation,
      head_doc="depth estimation",
  )
  # Auto class bound to MODEL_FOR_TEXT_RECOGNITION_MAPPING.
  class AutoModelForTextRecognition(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_TEXT_RECOGNITION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "text recognition".
  AutoModelForTextRecognition = auto_class_update(
      AutoModelForTextRecognition,
      head_doc="text recognition",
  )
  # Auto class bound to MODEL_FOR_TABLE_RECOGNITION_MAPPING.
  class AutoModelForTableRecognition(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_TABLE_RECOGNITION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "table recognition".
  AutoModelForTableRecognition = auto_class_update(
      AutoModelForTableRecognition,
      head_doc="table recognition",
  )
  # Auto class bound to MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING.
  class AutoModelForVideoClassification(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "video classification".
  AutoModelForVideoClassification = auto_class_update(
      AutoModelForVideoClassification,
      head_doc="video classification",
  )
  # Auto class bound to MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING.
  class AutoModelForImageTextToText(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING

      # Thin override that exists only to narrow the annotated return type to
      # "_BaseModelWithGenerate"; all arguments are forwarded unchanged to the
      # parent implementation.
      @classmethod
      def from_pretrained(
          cls: type["AutoModelForImageTextToText"],
          pretrained_model_name_or_path: str | os.PathLike[str],
          *model_args,
          **kwargs,
      ) -> "_BaseModelWithGenerate":
          loaded_model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
          return loaded_model


  # Rebind the name to the result of `auto_class_update`, labelled "image-text-to-text modeling".
  AutoModelForImageTextToText = auto_class_update(
      AutoModelForImageTextToText,
      head_doc="image-text-to-text modeling",
  )
  # Auto class bound to MODEL_FOR_MULTIMODAL_LM_MAPPING.
  class AutoModelForMultimodalLM(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_MULTIMODAL_LM_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "multimodal generation".
  AutoModelForMultimodalLM = auto_class_update(
      AutoModelForMultimodalLM,
      head_doc="multimodal generation",
  )
  # Auto class bound to MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING.
  class AutoModelForAudioClassification(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "audio classification".
  AutoModelForAudioClassification = auto_class_update(
      AutoModelForAudioClassification,
      head_doc="audio classification",
  )
  # Auto class bound to MODEL_FOR_CTC_MAPPING.
  class AutoModelForCTC(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_CTC_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled
  # "connectionist temporal classification".
  AutoModelForCTC = auto_class_update(
      AutoModelForCTC,
      head_doc="connectionist temporal classification",
  )
  # Auto class bound to MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
  class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled
  # "sequence-to-sequence speech-to-text modeling".
  AutoModelForSpeechSeq2Seq = auto_class_update(
      AutoModelForSpeechSeq2Seq,
      head_doc="sequence-to-sequence speech-to-text modeling",
  )
  # Auto class bound to MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING.
  class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled
  # "audio frame (token) classification".
  AutoModelForAudioFrameClassification = auto_class_update(
      AutoModelForAudioFrameClassification,
      head_doc="audio frame (token) classification",
  )
  # Auto class bound to MODEL_FOR_AUDIO_XVECTOR_MAPPING.
  class AutoModelForAudioXVector(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


  # This call previously sat after AutoBackbone, three class definitions away from the
  # class it updates. It now follows its class directly, like every other
  # `auto_class_update` call in this section. The intervening classes never reference
  # AutoModelForAudioXVector, so the reorder does not change behaviour.
  AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")


  # Auto class bound to MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.
  # No `auto_class_update` call accompanies this class in this section.
  class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING


  # Auto class bound to MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING.
  # No `auto_class_update` call accompanies this class in this section.
  class AutoModelForTextToWaveform(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING


  # Backbone auto class, bound to MODEL_FOR_BACKBONE_MAPPING. Unlike its neighbours it
  # derives from `_BaseAutoBackboneClass` rather than `_BaseAutoModelClass`.
  class AutoBackbone(_BaseAutoBackboneClass):
      _model_mapping = MODEL_FOR_BACKBONE_MAPPING
  # Auto class bound to MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING.
  class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled "masked image modeling".
  AutoModelForMaskedImageModeling = auto_class_update(
      AutoModelForMaskedImageModeling,
      head_doc="masked image modeling",
  )
  # Auto class bound to MODEL_FOR_AUDIO_TOKENIZATION_MAPPING.
  class AutoModelForAudioTokenization(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_AUDIO_TOKENIZATION_MAPPING


  # Rebind the name to the result of `auto_class_update`, labelled
  # "audio tokenization through codebooks".
  AutoModelForAudioTokenization = auto_class_update(
      AutoModelForAudioTokenization,
      head_doc="audio tokenization through codebooks",
  )
  # Public names of this module: the task-specific mapping objects come first,
  # followed by the auto classes. Entry order is preserved exactly.
  # NOTE(review): kept as one flat literal with one name per line and no comments
  # inside the brackets; tooling that reads `__all__` textually may depend on that
  # shape. Confirm before restructuring.
  __all__ = [
      "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
      "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING",
      "MODEL_FOR_AUDIO_TOKENIZATION_MAPPING",
      "MODEL_FOR_AUDIO_XVECTOR_MAPPING",
      "MODEL_FOR_BACKBONE_MAPPING",
      "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
      "MODEL_FOR_CAUSAL_LM_MAPPING",
      "MODEL_FOR_CTC_MAPPING",
      "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
      "MODEL_FOR_DEPTH_ESTIMATION_MAPPING",
      "MODEL_FOR_TEXT_RECOGNITION_MAPPING",
      "MODEL_FOR_TABLE_RECOGNITION_MAPPING",
      "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
      "MODEL_FOR_IMAGE_MAPPING",
      "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
      "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
      "MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
      "MODEL_FOR_KEYPOINT_MATCHING_MAPPING",
      "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
      "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
      "MODEL_FOR_MASKED_LM_MAPPING",
      "MODEL_FOR_MASK_GENERATION_MAPPING",
      "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
      "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
      "MODEL_FOR_OBJECT_DETECTION_MAPPING",
      "MODEL_FOR_PRETRAINING_MAPPING",
      "MODEL_FOR_QUESTION_ANSWERING_MAPPING",
      "MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
      "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
      "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
      "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
      "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
      "MODEL_FOR_TEXT_ENCODING_MAPPING",
      "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
      "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
      "MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING",
      "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
      "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
      "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
      "MODEL_FOR_RETRIEVAL_MAPPING",
      "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING",
      "MODEL_FOR_MULTIMODAL_LM_MAPPING",
      "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
      "MODEL_MAPPING",
      "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
      "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
      "MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING",
      "MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING",
      "AutoModel",
      "AutoBackbone",
      "AutoModelForAudioClassification",
      "AutoModelForAudioFrameClassification",
      "AutoModelForAudioTokenization",
      "AutoModelForAudioXVector",
      "AutoModelForCausalLM",
      "AutoModelForCTC",
      "AutoModelForDepthEstimation",
      "AutoModelForTextRecognition",
      "AutoModelForTableRecognition",
      "AutoModelForImageClassification",
      "AutoModelForImageSegmentation",
      "AutoModelForImageToImage",
      "AutoModelForInstanceSegmentation",
      "AutoModelForKeypointDetection",
      "AutoModelForKeypointMatching",
      "AutoModelForMaskGeneration",
      "AutoModelForTextEncoding",
      "AutoModelForMaskedImageModeling",
      "AutoModelForMaskedLM",
      "AutoModelForMultipleChoice",
      "AutoModelForMultimodalLM",
      "AutoModelForNextSentencePrediction",
      "AutoModelForObjectDetection",
      "AutoModelForPreTraining",
      "AutoModelForQuestionAnswering",
      "AutoModelForSemanticSegmentation",
      "AutoModelForSeq2SeqLM",
      "AutoModelForSequenceClassification",
      "AutoModelForSpeechSeq2Seq",
      "AutoModelForTableQuestionAnswering",
      "AutoModelForTextToSpectrogram",
      "AutoModelForTextToWaveform",
      "AutoModelForTimeSeriesPrediction",
      "AutoModelForTokenClassification",
      "AutoModelForUniversalSegmentation",
      "AutoModelForVideoClassification",
      "AutoModelForVisualQuestionAnswering",
      "AutoModelForDocumentQuestionAnswering",
      "AutoModelForZeroShotImageClassification",
      "AutoModelForZeroShotObjectDetection",
      "AutoModelForImageTextToText",
  ]
  2141. ]