| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376 |
- # Copyright 2018 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Auto Model class."""
- import os
- from collections import OrderedDict
- from typing import TYPE_CHECKING
- from ...utils import logging
- from .auto_factory import (
- _BaseAutoBackboneClass,
- _BaseAutoModelClass,
- _LazyAutoMapping,
- auto_class_update,
- )
- from .configuration_auto import CONFIG_MAPPING_NAMES
if TYPE_CHECKING:
    # Everything in this branch is seen only by static type checkers; at runtime
    # TYPE_CHECKING is False, so neither import below is executed.
    from ...modeling_utils import PreTrainedModel
    from ...generation import GenerationMixin

    class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
        """Typing-only class combining `PreTrainedModel` and `GenerationMixin`.

        Declared purely so type annotations can name a model that has both bases.
        """
# Module logger, named after this module via the package's logging utilities.
logger = logging.get_logger(name=__name__)
# Base model mapping: model-type key -> name of the model class, stored as a string.
# A tuple value (e.g. "funnel") maps one key to more than one class name, and several
# keys may share the same class (e.g. "gpt-sw3"/"gpt2", "mlcd"/"mlcd_vision_model").
# Keep keys unique and sorted: an OrderedDict built from a list of pairs silently keeps
# only the last value for a repeated key, so a duplicated entry is at best redundant and
# at worst a hidden override of an earlier one.
MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("afmoe", "AfmoeModel"),
        ("aimv2", "Aimv2Model"),
        ("aimv2_vision_model", "Aimv2VisionModel"),
        ("albert", "AlbertModel"),
        ("align", "AlignModel"),
        ("altclip", "AltCLIPModel"),
        ("apertus", "ApertusModel"),
        ("arcee", "ArceeModel"),
        ("aria", "AriaModel"),
        ("aria_text", "AriaTextModel"),
        ("audio-spectrogram-transformer", "ASTModel"),
        ("audioflamingo3", "AudioFlamingo3ForConditionalGeneration"),
        ("audioflamingo3_encoder", "AudioFlamingo3Encoder"),
        ("autoformer", "AutoformerModel"),
        ("aya_vision", "AyaVisionModel"),
        ("bamba", "BambaModel"),
        ("bark", "BarkModel"),
        ("bart", "BartModel"),
        ("beit", "BeitModel"),
        ("bert", "BertModel"),
        ("bert-generation", "BertGenerationEncoder"),
        ("big_bird", "BigBirdModel"),
        ("bigbird_pegasus", "BigBirdPegasusModel"),
        ("biogpt", "BioGptModel"),
        ("bit", "BitModel"),
        ("bitnet", "BitNetModel"),
        ("blenderbot", "BlenderbotModel"),
        ("blenderbot-small", "BlenderbotSmallModel"),
        ("blip", "BlipModel"),
        ("blip-2", "Blip2Model"),
        ("blip_2_qformer", "Blip2QFormerModel"),
        ("bloom", "BloomModel"),
        ("blt", "BltModel"),
        ("bridgetower", "BridgeTowerModel"),
        ("bros", "BrosModel"),
        ("camembert", "CamembertModel"),
        ("canine", "CanineModel"),
        ("chameleon", "ChameleonModel"),
        ("chinese_clip", "ChineseCLIPModel"),
        ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
        ("clap", "ClapModel"),
        ("clip", "CLIPModel"),
        ("clip_text_model", "CLIPTextModel"),
        ("clip_vision_model", "CLIPVisionModel"),
        ("clipseg", "CLIPSegModel"),
        ("clvp", "ClvpModelForConditionalGeneration"),
        ("code_llama", "LlamaModel"),
        ("codegen", "CodeGenModel"),
        ("cohere", "CohereModel"),
        ("cohere2", "Cohere2Model"),
        ("cohere2_vision", "Cohere2VisionModel"),
        ("cohere_asr", "CohereAsrModel"),
        ("conditional_detr", "ConditionalDetrModel"),
        ("convbert", "ConvBertModel"),
        ("convnext", "ConvNextModel"),
        ("convnextv2", "ConvNextV2Model"),
        ("cpmant", "CpmAntModel"),
        ("csm", "CsmForConditionalGeneration"),
        ("ctrl", "CTRLModel"),
        ("cvt", "CvtModel"),
        ("cwm", "CwmModel"),
        ("d_fine", "DFineModel"),
        ("dab-detr", "DabDetrModel"),
        ("dac", "DacModel"),
        ("data2vec-audio", "Data2VecAudioModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("dbrx", "DbrxModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("decision_transformer", "DecisionTransformerModel"),
        ("deepseek_v2", "DeepseekV2Model"),
        ("deepseek_v3", "DeepseekV3Model"),
        ("deepseek_vl", "DeepseekVLModel"),
        ("deepseek_vl_hybrid", "DeepseekVLHybridModel"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("depth_pro", "DepthProModel"),
        ("detr", "DetrModel"),
        ("dia", "DiaModel"),
        ("diffllama", "DiffLlamaModel"),
        ("dinat", "DinatModel"),
        ("dinov2", "Dinov2Model"),
        ("dinov2_with_registers", "Dinov2WithRegistersModel"),
        ("dinov3_convnext", "DINOv3ConvNextModel"),
        ("dinov3_vit", "DINOv3ViTModel"),
        ("distilbert", "DistilBertModel"),
        ("doge", "DogeModel"),
        ("donut-swin", "DonutSwinModel"),
        ("dots1", "Dots1Model"),
        ("dpr", "DPRQuestionEncoder"),
        ("dpt", "DPTModel"),
        ("edgetam", "EdgeTamModel"),
        ("edgetam_video", "EdgeTamVideoModel"),
        ("edgetam_vision_model", "EdgeTamVisionModel"),
        ("efficientloftr", "EfficientLoFTRModel"),
        ("efficientnet", "EfficientNetModel"),
        ("electra", "ElectraModel"),
        ("emu3", "Emu3Model"),
        ("encodec", "EncodecModel"),
        ("ernie", "ErnieModel"),
        ("ernie4_5", "Ernie4_5Model"),
        ("ernie4_5_moe", "Ernie4_5_MoeModel"),
        ("ernie4_5_vl_moe", "Ernie4_5_VLMoeModel"),
        ("esm", "EsmModel"),
        ("eurobert", "EuroBertModel"),
        ("evolla", "EvollaModel"),
        ("exaone4", "Exaone4Model"),
        ("exaone_moe", "ExaoneMoeModel"),
        ("falcon", "FalconModel"),
        ("falcon_h1", "FalconH1Model"),
        ("falcon_mamba", "FalconMambaModel"),
        ("fast_vlm", "FastVlmModel"),
        ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
        ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
        ("flaubert", "FlaubertModel"),
        ("flava", "FlavaModel"),
        ("flex_olmo", "FlexOlmoModel"),
        ("florence2", "Florence2Model"),
        ("fnet", "FNetModel"),
        ("focalnet", "FocalNetModel"),
        ("fsmt", "FSMTModel"),
        ("funnel", ("FunnelModel", "FunnelBaseModel")),
        ("fuyu", "FuyuModel"),
        ("gemma", "GemmaModel"),
        ("gemma2", "Gemma2Model"),
        ("gemma3", "Gemma3Model"),
        ("gemma3_text", "Gemma3TextModel"),
        ("gemma3n", "Gemma3nModel"),
        ("gemma3n_audio", "Gemma3nAudioEncoder"),
        ("gemma3n_text", "Gemma3nTextModel"),
        ("gemma3n_vision", "TimmWrapperModel"),
        ("gemma4", "Gemma4Model"),
        ("gemma4_audio", "Gemma4AudioModel"),
        ("gemma4_text", "Gemma4TextModel"),
        ("gemma4_vision", "Gemma4VisionModel"),
        ("git", "GitModel"),
        ("glm", "GlmModel"),
        ("glm4", "Glm4Model"),
        ("glm46v", "Glm46VModel"),
        ("glm4_moe", "Glm4MoeModel"),
        ("glm4_moe_lite", "Glm4MoeLiteModel"),
        ("glm4v", "Glm4vModel"),
        ("glm4v_moe", "Glm4vMoeModel"),
        ("glm4v_moe_text", "Glm4vMoeTextModel"),
        ("glm4v_moe_vision", "Glm4vMoeVisionModel"),
        ("glm4v_text", "Glm4vTextModel"),
        ("glm4v_vision", "Glm4vVisionModel"),
        ("glm_image", "GlmImageModel"),
        ("glm_image_text", "GlmImageTextModel"),
        ("glm_image_vision", "GlmImageVisionModel"),
        ("glm_image_vqmodel", "GlmImageVQVAE"),
        ("glm_moe_dsa", "GlmMoeDsaModel"),
        ("glm_ocr", "GlmOcrModel"),
        ("glm_ocr_text", "GlmOcrTextModel"),
        ("glm_ocr_vision", "GlmOcrVisionModel"),
        ("glmasr", "GlmAsrForConditionalGeneration"),
        ("glmasr_encoder", "GlmAsrEncoder"),
        ("glpn", "GLPNModel"),
        ("got_ocr2", "GotOcr2Model"),
        ("gpt-sw3", "GPT2Model"),
        ("gpt2", "GPT2Model"),
        ("gpt_bigcode", "GPTBigCodeModel"),
        ("gpt_neo", "GPTNeoModel"),
        ("gpt_neox", "GPTNeoXModel"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
        ("gpt_oss", "GptOssModel"),
        ("gptj", "GPTJModel"),
        ("granite", "GraniteModel"),
        ("granitemoe", "GraniteMoeModel"),
        ("granitemoehybrid", "GraniteMoeHybridModel"),
        ("granitemoeshared", "GraniteMoeSharedModel"),
        ("grounding-dino", "GroundingDinoModel"),
        ("groupvit", "GroupViTModel"),
        ("helium", "HeliumModel"),
        ("hgnet_v2", "HGNetV2Backbone"),
        ("hiera", "HieraModel"),
        ("higgs_audio_v2", "HiggsAudioV2ForConditionalGeneration"),
        ("higgs_audio_v2_tokenizer", "HiggsAudioV2TokenizerModel"),
        ("hubert", "HubertModel"),
        ("hunyuan_v1_dense", "HunYuanDenseV1Model"),
        ("hunyuan_v1_moe", "HunYuanMoEV1Model"),
        ("ibert", "IBertModel"),
        ("idefics", "IdeficsModel"),
        ("idefics2", "Idefics2Model"),
        ("idefics3", "Idefics3Model"),
        ("idefics3_vision", "Idefics3VisionTransformer"),
        ("ijepa", "IJepaModel"),
        ("imagegpt", "ImageGPTModel"),
        ("informer", "InformerModel"),
        ("instructblip", "InstructBlipModel"),
        ("instructblipvideo", "InstructBlipVideoModel"),
        ("internvl", "InternVLModel"),
        ("internvl_vision", "InternVLVisionModel"),
        ("jais2", "Jais2Model"),
        ("jamba", "JambaModel"),
        ("janus", "JanusModel"),
        ("jetmoe", "JetMoeModel"),
        ("jina_embeddings_v3", "JinaEmbeddingsV3Model"),
        ("kosmos-2", "Kosmos2Model"),
        ("kosmos-2.5", "Kosmos2_5Model"),
        ("kyutai_speech_to_text", "KyutaiSpeechToTextModel"),
        ("lasr_ctc", "LasrForCTC"),
        ("lasr_encoder", "LasrEncoder"),
        ("layoutlm", "LayoutLMModel"),
        ("layoutlmv2", "LayoutLMv2Model"),
        ("layoutlmv3", "LayoutLMv3Model"),
        ("led", "LEDModel"),
        ("levit", "LevitModel"),
        ("lfm2", "Lfm2Model"),
        ("lfm2_moe", "Lfm2MoeModel"),
        ("lfm2_vl", "Lfm2VlModel"),
        ("lightglue", "LightGlueForKeypointMatching"),
        ("lighton_ocr", "LightOnOcrModel"),
        ("lilt", "LiltModel"),
        ("llama", "LlamaModel"),
        ("llama4", "Llama4ForConditionalGeneration"),
        ("llama4_text", "Llama4TextModel"),
        ("llava", "LlavaModel"),
        ("llava_next", "LlavaNextModel"),
        ("llava_next_video", "LlavaNextVideoModel"),
        ("llava_onevision", "LlavaOnevisionModel"),
        ("longcat_flash", "LongcatFlashModel"),
        ("longformer", "LongformerModel"),
        ("longt5", "LongT5Model"),
        ("luke", "LukeModel"),
        ("lw_detr", "LwDetrModel"),
        ("lxmert", "LxmertModel"),
        ("m2m_100", "M2M100Model"),
        ("mamba", "MambaModel"),
        ("mamba2", "Mamba2Model"),
        ("marian", "MarianModel"),
        ("markuplm", "MarkupLMModel"),
        ("mask2former", "Mask2FormerModel"),
        ("maskformer", "MaskFormerModel"),
        ("maskformer-swin", "MaskFormerSwinModel"),
        ("mbart", "MBartModel"),
        ("megatron-bert", "MegatronBertModel"),
        ("metaclip_2", "MetaClip2Model"),
        ("mgp-str", "MgpstrForSceneTextRecognition"),
        ("mimi", "MimiModel"),
        ("minimax", "MiniMaxModel"),
        ("minimax_m2", "MiniMaxM2Model"),
        ("ministral", "MinistralModel"),
        ("ministral3", "Ministral3Model"),
        ("mistral", "MistralModel"),
        ("mistral3", "Mistral3Model"),
        ("mistral4", "Mistral4Model"),
        ("mixtral", "MixtralModel"),
        ("mlcd", "MLCDVisionModel"),  # Kept so that some original hub repositories (from `DeepGlint-AI`) still work
        ("mlcd_vision_model", "MLCDVisionModel"),
        ("mllama", "MllamaModel"),
        ("mm-grounding-dino", "MMGroundingDinoModel"),
        ("mobilebert", "MobileBertModel"),
        ("mobilenet_v1", "MobileNetV1Model"),
        ("mobilenet_v2", "MobileNetV2Model"),
        ("mobilevit", "MobileViTModel"),
        ("mobilevitv2", "MobileViTV2Model"),
        ("modernbert", "ModernBertModel"),
        ("modernbert-decoder", "ModernBertDecoderModel"),
        ("modernvbert", "ModernVBertModel"),
        ("moonshine", "MoonshineModel"),
        ("moonshine_streaming", "MoonshineStreamingModel"),
        ("moshi", "MoshiModel"),
        ("mpnet", "MPNetModel"),
        ("mpt", "MptModel"),
        ("mra", "MraModel"),
        ("mt5", "MT5Model"),
        ("musicflamingo", "MusicFlamingoForConditionalGeneration"),
        ("musicflamingo_encoder", "AudioFlamingo3Encoder"),
        ("musicgen", "MusicgenModel"),
        ("musicgen_melody", "MusicgenMelodyModel"),
        ("mvp", "MvpModel"),
        ("nanochat", "NanoChatModel"),
        ("nemotron", "NemotronModel"),
        ("nemotron_h", "NemotronHModel"),
        ("nllb-moe", "NllbMoeModel"),
        ("nomic_bert", "NomicBertModel"),
        ("nystromformer", "NystromformerModel"),
        ("olmo", "OlmoModel"),
        ("olmo2", "Olmo2Model"),
        ("olmo3", "Olmo3Model"),
        ("olmo_hybrid", "OlmoHybridModel"),
        ("olmoe", "OlmoeModel"),
        ("omdet-turbo", "OmDetTurboForObjectDetection"),
        ("oneformer", "OneFormerModel"),
        ("openai-gpt", "OpenAIGPTModel"),
        ("opt", "OPTModel"),
        ("ovis2", "Ovis2Model"),
        ("owlv2", "Owlv2Model"),
        ("owlvit", "OwlViTModel"),
        ("paligemma", "PaliGemmaModel"),
        ("parakeet_ctc", "ParakeetForCTC"),
        ("parakeet_encoder", "ParakeetEncoder"),
        ("patchtsmixer", "PatchTSMixerModel"),
        ("patchtst", "PatchTSTModel"),
        ("pe_audio", "PeAudioModel"),
        ("pe_audio_encoder", "PeAudioEncoder"),
        ("pe_audio_video", "PeAudioVideoModel"),
        ("pe_audio_video_encoder", "PeAudioVideoEncoder"),
        ("pe_video", "PeVideoModel"),
        ("pe_video_encoder", "PeVideoEncoder"),
        ("pegasus", "PegasusModel"),
        ("pegasus_x", "PegasusXModel"),
        ("perceiver", "PerceiverModel"),
        ("perception_lm", "PerceptionLMModel"),
        ("persimmon", "PersimmonModel"),
        ("phi", "PhiModel"),
        ("phi3", "Phi3Model"),
        ("phi4_multimodal", "Phi4MultimodalModel"),
        ("phimoe", "PhimoeModel"),
        ("pi0", "PI0Model"),
        ("pixio", "PixioModel"),
        ("pixtral", "PixtralVisionModel"),
        ("plbart", "PLBartModel"),
        ("poolformer", "PoolFormerModel"),
        ("pp_doclayout_v3", "PPDocLayoutV3Model"),
        ("pp_ocrv5_mobile_rec", "PPOCRV5MobileRecModel"),
        ("pp_ocrv5_server_rec", "PPOCRV5ServerRecModel"),
        ("prophetnet", "ProphetNetModel"),
        ("pvt", "PvtModel"),
        ("pvt_v2", "PvtV2Model"),
        ("qwen2", "Qwen2Model"),
        ("qwen2_5_vl", "Qwen2_5_VLModel"),
        ("qwen2_5_vl_text", "Qwen2_5_VLTextModel"),
        ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
        ("qwen2_moe", "Qwen2MoeModel"),
        ("qwen2_vl", "Qwen2VLModel"),
        ("qwen2_vl_text", "Qwen2VLTextModel"),
        ("qwen3", "Qwen3Model"),
        ("qwen3_5", "Qwen3_5Model"),
        ("qwen3_5_moe", "Qwen3_5MoeModel"),
        ("qwen3_5_moe_text", "Qwen3_5MoeTextModel"),
        ("qwen3_5_text", "Qwen3_5TextModel"),
        ("qwen3_moe", "Qwen3MoeModel"),
        ("qwen3_next", "Qwen3NextModel"),
        ("qwen3_vl", "Qwen3VLModel"),
        ("qwen3_vl_moe", "Qwen3VLMoeModel"),
        ("qwen3_vl_moe_text", "Qwen3VLMoeTextModel"),
        ("qwen3_vl_text", "Qwen3VLTextModel"),
        ("recurrent_gemma", "RecurrentGemmaModel"),
        ("reformer", "ReformerModel"),
        ("regnet", "RegNetModel"),
        ("rembert", "RemBertModel"),
        ("resnet", "ResNetModel"),
        ("roberta", "RobertaModel"),
        ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
        ("roc_bert", "RoCBertModel"),
        ("roformer", "RoFormerModel"),
        ("rt_detr", "RTDetrModel"),
        ("rt_detr_v2", "RTDetrV2Model"),
        ("rwkv", "RwkvModel"),
        ("sam", "SamModel"),
        ("sam2", "Sam2Model"),
        ("sam2_hiera_det_model", "Sam2HieraDetModel"),
        ("sam2_video", "Sam2VideoModel"),
        ("sam2_vision_model", "Sam2VisionModel"),
        ("sam3", "Sam3Model"),
        ("sam3_tracker", "Sam3TrackerModel"),
        ("sam3_tracker_video", "Sam3TrackerVideoModel"),
        ("sam3_video", "Sam3VideoModel"),
        ("sam3_vision_model", "Sam3VisionModel"),
        ("sam3_vit_model", "Sam3ViTModel"),
        ("sam_hq", "SamHQModel"),
        ("sam_hq_vision_model", "SamHQVisionModel"),
        ("sam_vision_model", "SamVisionModel"),
        ("seamless_m4t", "SeamlessM4TModel"),
        ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
        ("seed_oss", "SeedOssModel"),
        ("segformer", "SegformerModel"),
        ("seggpt", "SegGptModel"),
        ("sew", "SEWModel"),
        ("sew-d", "SEWDModel"),
        ("siglip", "SiglipModel"),
        ("siglip2", "Siglip2Model"),
        ("siglip2_vision_model", "Siglip2VisionModel"),
        ("siglip_vision_model", "SiglipVisionModel"),
        ("smollm3", "SmolLM3Model"),
        ("smolvlm", "SmolVLMModel"),
        ("smolvlm_vision", "SmolVLMVisionTransformer"),
        ("solar_open", "SolarOpenModel"),
        ("speech_to_text", "Speech2TextModel"),
        ("speecht5", "SpeechT5Model"),
        ("splinter", "SplinterModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("stablelm", "StableLmModel"),
        ("starcoder2", "Starcoder2Model"),
        ("swiftformer", "SwiftFormerModel"),
        ("swin", "SwinModel"),
        ("swin2sr", "Swin2SRModel"),
        ("swinv2", "Swinv2Model"),
        ("switch_transformers", "SwitchTransformersModel"),
        ("t5", "T5Model"),
        ("t5gemma", "T5GemmaModel"),
        ("t5gemma2", "T5Gemma2Model"),
        ("t5gemma2_encoder", "T5Gemma2Encoder"),
        ("table-transformer", "TableTransformerModel"),
        ("tapas", "TapasModel"),
        ("textnet", "TextNetModel"),
        ("time_series_transformer", "TimeSeriesTransformerModel"),
        ("timesfm", "TimesFmModel"),
        ("timesfm2_5", "TimesFm2_5Model"),
        ("timesformer", "TimesformerModel"),
        ("timm_backbone", "TimmBackbone"),
        ("timm_wrapper", "TimmWrapperModel"),
        ("tvp", "TvpModel"),
        ("udop", "UdopModel"),
        ("umt5", "UMT5Model"),
        ("unispeech", "UniSpeechModel"),
        ("unispeech-sat", "UniSpeechSatModel"),
        ("univnet", "UnivNetModel"),
        ("uvdoc", "UVDocModel"),
        ("vaultgemma", "VaultGemmaModel"),
        ("vibevoice_acoustic_tokenizer", "VibeVoiceAcousticTokenizerModel"),
        ("vibevoice_acoustic_tokenizer_decoder", "VibeVoiceAcousticTokenizerDecoderModel"),
        ("vibevoice_acoustic_tokenizer_encoder", "VibeVoiceAcousticTokenizerEncoderModel"),
        ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"),
        ("video_llama_3", "VideoLlama3Model"),
        ("video_llama_3_vision", "VideoLlama3VisionModel"),
        ("video_llava", "VideoLlavaModel"),
        ("videomae", "VideoMAEModel"),
        ("vilt", "ViltModel"),
        ("vipllava", "VipLlavaModel"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
        ("visual_bert", "VisualBertModel"),
        ("vit", "ViTModel"),
        ("vit_mae", "ViTMAEModel"),
        ("vit_msn", "ViTMSNModel"),
        ("vitdet", "VitDetModel"),
        ("vits", "VitsModel"),
        ("vivit", "VivitModel"),
        ("vjepa2", "VJEPA2Model"),
        ("voxtral", "VoxtralForConditionalGeneration"),
        ("voxtral_encoder", "VoxtralEncoder"),
        ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"),
        ("voxtral_realtime_encoder", "VoxtralRealtimeEncoder"),
        ("voxtral_realtime_text", "VoxtralRealtimeTextModel"),
        ("wav2vec2", "Wav2Vec2Model"),
        ("wav2vec2-bert", "Wav2Vec2BertModel"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
        ("wavlm", "WavLMModel"),
        ("whisper", "WhisperModel"),
        ("xclip", "XCLIPModel"),
        ("xcodec", "XcodecModel"),
        ("xglm", "XGLMModel"),
        ("xlm", "XLMModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
        ("xlnet", "XLNetModel"),
        ("xlstm", "xLSTMModel"),
        ("xmod", "XmodModel"),
        ("yolos", "YolosModel"),
        ("yoso", "YosoModel"),
        ("youtu", "YoutuModel"),
        ("zamba", "ZambaModel"),
        ("zamba2", "Zamba2Model"),
    ]
)
# Pre-training mapping: model-type key -> name of the class to use (stored as a string).
# Values are not all "...ForPreTraining" classes (e.g. "bert"); many keys point at
# masked-LM ("camembert"), causal-LM ("bloom"), LM-head ("gpt2"), conditional-generation
# ("t5") or retrieval ("colpali") classes instead.
# NOTE(review): presumably consumed by a pre-training auto class defined later in this
# file -- confirm. Keys are kept unique and alphabetically sorted.
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "AlbertForPreTraining"),
        ("audioflamingo3", "AudioFlamingo3ForConditionalGeneration"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForPreTraining"),
        ("big_bird", "BigBirdForPreTraining"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("colmodernvbert", "ColModernVBertForRetrieval"),
        ("colpali", "ColPaliForRetrieval"),
        ("colqwen2", "ColQwen2ForRetrieval"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForPreTraining"),
        ("ernie", "ErnieForPreTraining"),
        ("evolla", "EvollaForProteinText2Text"),
        ("exaone4", "Exaone4ForCausalLM"),
        ("exaone_moe", "ExaoneMoeForCausalLM"),
        ("falcon_mamba", "FalconMambaForCausalLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("flava", "FlavaForPreTraining"),
        ("florence2", "Florence2ForConditionalGeneration"),
        ("fnet", "FNetForPreTraining"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForPreTraining"),
        ("gemma3", "Gemma3ForConditionalGeneration"),
        ("gemma4", "Gemma4ForConditionalGeneration"),
        ("glmasr", "GlmAsrForConditionalGeneration"),
        ("gpt-sw3", "GPT2LMHeadModel"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_bigcode", "GPTBigCodeForCausalLM"),
        ("hiera", "HieraForPreTraining"),
        ("ibert", "IBertForMaskedLM"),
        ("idefics", "IdeficsForVisionText2Text"),
        ("idefics2", "Idefics2ForConditionalGeneration"),
        ("idefics3", "Idefics3ForConditionalGeneration"),
        ("janus", "JanusForConditionalGeneration"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("llava", "LlavaForConditionalGeneration"),
        ("llava_next", "LlavaNextForConditionalGeneration"),
        ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
        ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("lxmert", "LxmertForPreTraining"),
        ("mamba", "MambaForCausalLM"),
        ("mamba2", "Mamba2ForCausalLM"),
        ("megatron-bert", "MegatronBertForPreTraining"),
        ("mistral3", "Mistral3ForConditionalGeneration"),
        ("mistral4", "Mistral4ForCausalLM"),
        ("mllama", "MllamaForConditionalGeneration"),
        ("mobilebert", "MobileBertForPreTraining"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mpt", "MptForCausalLM"),
        ("mra", "MraForMaskedLM"),
        ("musicflamingo", "MusicFlamingoForConditionalGeneration"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nanochat", "NanoChatForCausalLM"),
        ("nllb-moe", "NllbMoeForConditionalGeneration"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("paligemma", "PaliGemmaForConditionalGeneration"),
        ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
        ("roberta", "RobertaForMaskedLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForPreTraining"),
        ("rwkv", "RwkvForCausalLM"),
        ("splinter", "SplinterForPreTraining"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("t5gemma", "T5GemmaForConditionalGeneration"),
        ("t5gemma2", "T5Gemma2ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("unispeech", "UniSpeechForPreTraining"),
        ("unispeech-sat", "UniSpeechSatForPreTraining"),
        ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"),
        ("video_llava", "VideoLlavaForConditionalGeneration"),
        ("videomae", "VideoMAEForPreTraining"),
        ("vipllava", "VipLlavaForConditionalGeneration"),
        ("visual_bert", "VisualBertForPreTraining"),
        ("vit_mae", "ViTMAEForPreTraining"),
        ("voxtral", "VoxtralForConditionalGeneration"),
        ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"),
        ("wav2vec2", "Wav2Vec2ForPreTraining"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("xlstm", "xLSTMForCausalLM"),
        ("xmod", "XmodForMaskedLM"),
    ]
)
- MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
- [
- # Model for Causal LM mapping
- ("afmoe", "AfmoeForCausalLM"),
- ("apertus", "ApertusForCausalLM"),
- ("arcee", "ArceeForCausalLM"),
- ("aria_text", "AriaTextForCausalLM"),
- ("bamba", "BambaForCausalLM"),
- ("bart", "BartForCausalLM"),
- ("bert", "BertLMHeadModel"),
- ("bert-generation", "BertGenerationDecoder"),
- ("big_bird", "BigBirdForCausalLM"),
- ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
- ("biogpt", "BioGptForCausalLM"),
- ("bitnet", "BitNetForCausalLM"),
- ("blenderbot", "BlenderbotForCausalLM"),
- ("blenderbot-small", "BlenderbotSmallForCausalLM"),
- ("bloom", "BloomForCausalLM"),
- ("blt", "BltForCausalLM"),
- ("camembert", "CamembertForCausalLM"),
- ("code_llama", "LlamaForCausalLM"),
- ("codegen", "CodeGenForCausalLM"),
- ("cohere", "CohereForCausalLM"),
- ("cohere2", "Cohere2ForCausalLM"),
- ("cpmant", "CpmAntForCausalLM"),
- ("ctrl", "CTRLLMHeadModel"),
- ("cwm", "CwmForCausalLM"),
- ("data2vec-text", "Data2VecTextForCausalLM"),
- ("dbrx", "DbrxForCausalLM"),
- ("deepseek_v2", "DeepseekV2ForCausalLM"),
- ("deepseek_v3", "DeepseekV3ForCausalLM"),
- ("diffllama", "DiffLlamaForCausalLM"),
- ("doge", "DogeForCausalLM"),
- ("dots1", "Dots1ForCausalLM"),
- ("electra", "ElectraForCausalLM"),
- ("emu3", "Emu3ForCausalLM"),
- ("ernie", "ErnieForCausalLM"),
- ("ernie4_5", "Ernie4_5ForCausalLM"),
- ("ernie4_5_moe", "Ernie4_5_MoeForCausalLM"),
- ("exaone4", "Exaone4ForCausalLM"),
- ("exaone_moe", "ExaoneMoeForCausalLM"),
- ("falcon", "FalconForCausalLM"),
- ("falcon_h1", "FalconH1ForCausalLM"),
- ("falcon_mamba", "FalconMambaForCausalLM"),
- ("flex_olmo", "FlexOlmoForCausalLM"),
- ("fuyu", "FuyuForCausalLM"),
- ("gemma", "GemmaForCausalLM"),
- ("gemma2", "Gemma2ForCausalLM"),
- ("gemma3", "Gemma3ForConditionalGeneration"),
- ("gemma3_text", "Gemma3ForCausalLM"),
- ("gemma3n", "Gemma3nForConditionalGeneration"),
- ("gemma3n_text", "Gemma3nForCausalLM"),
- ("gemma4", "Gemma4ForConditionalGeneration"),
- ("gemma4_text", "Gemma4ForCausalLM"),
- ("git", "GitForCausalLM"),
- ("glm", "GlmForCausalLM"),
- ("glm4", "Glm4ForCausalLM"),
- ("glm4_moe", "Glm4MoeForCausalLM"),
- ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
- ("glm_moe_dsa", "GlmMoeDsaForCausalLM"),
- ("got_ocr2", "GotOcr2ForConditionalGeneration"),
- ("gpt-sw3", "GPT2LMHeadModel"),
- ("gpt2", "GPT2LMHeadModel"),
- ("gpt_bigcode", "GPTBigCodeForCausalLM"),
- ("gpt_neo", "GPTNeoForCausalLM"),
- ("gpt_neox", "GPTNeoXForCausalLM"),
- ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
- ("gpt_oss", "GptOssForCausalLM"),
- ("gptj", "GPTJForCausalLM"),
- ("granite", "GraniteForCausalLM"),
- ("granitemoe", "GraniteMoeForCausalLM"),
- ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),
- ("granitemoeshared", "GraniteMoeSharedForCausalLM"),
- ("helium", "HeliumForCausalLM"),
- ("hunyuan_v1_dense", "HunYuanDenseV1ForCausalLM"),
- ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
- ("jais2", "Jais2ForCausalLM"),
- ("jamba", "JambaForCausalLM"),
- ("jetmoe", "JetMoeForCausalLM"),
- ("lfm2", "Lfm2ForCausalLM"),
- ("lfm2_moe", "Lfm2MoeForCausalLM"),
- ("llama", "LlamaForCausalLM"),
- ("llama4", "Llama4ForCausalLM"),
- ("llama4_text", "Llama4ForCausalLM"),
- ("longcat_flash", "LongcatFlashForCausalLM"),
- ("mamba", "MambaForCausalLM"),
- ("mamba2", "Mamba2ForCausalLM"),
- ("marian", "MarianForCausalLM"),
- ("mbart", "MBartForCausalLM"),
- ("megatron-bert", "MegatronBertForCausalLM"),
- ("minimax", "MiniMaxForCausalLM"),
- ("minimax_m2", "MiniMaxM2ForCausalLM"),
- ("ministral", "MinistralForCausalLM"),
- ("ministral3", "Ministral3ForCausalLM"),
- ("mistral", "MistralForCausalLM"),
- ("mixtral", "MixtralForCausalLM"),
- ("mllama", "MllamaForCausalLM"),
- ("modernbert-decoder", "ModernBertDecoderForCausalLM"),
- ("moshi", "MoshiForCausalLM"),
- ("mpt", "MptForCausalLM"),
- ("musicgen", "MusicgenForCausalLM"),
- ("musicgen_melody", "MusicgenMelodyForCausalLM"),
- ("mvp", "MvpForCausalLM"),
- ("nanochat", "NanoChatForCausalLM"),
- ("nemotron", "NemotronForCausalLM"),
- ("nemotron_h", "NemotronHForCausalLM"),
- ("olmo", "OlmoForCausalLM"),
- ("olmo2", "Olmo2ForCausalLM"),
- ("olmo3", "Olmo3ForCausalLM"),
- ("olmo_hybrid", "OlmoHybridForCausalLM"),
- ("olmoe", "OlmoeForCausalLM"),
- ("openai-gpt", "OpenAIGPTLMHeadModel"),
- ("opt", "OPTForCausalLM"),
- ("pegasus", "PegasusForCausalLM"),
- ("persimmon", "PersimmonForCausalLM"),
- ("phi", "PhiForCausalLM"),
- ("phi3", "Phi3ForCausalLM"),
- ("phi4_multimodal", "Phi4MultimodalForCausalLM"),
- ("phimoe", "PhimoeForCausalLM"),
- ("plbart", "PLBartForCausalLM"),
- ("prophetnet", "ProphetNetForCausalLM"),
- ("qwen2", "Qwen2ForCausalLM"),
- ("qwen2_moe", "Qwen2MoeForCausalLM"),
- ("qwen3", "Qwen3ForCausalLM"),
- ("qwen3_5", "Qwen3_5ForCausalLM"), # VLM compatibility
- ("qwen3_5_moe", "Qwen3_5MoeForCausalLM"), # VLM compatibility
- ("qwen3_5_moe_text", "Qwen3_5MoeForCausalLM"),
- ("qwen3_5_text", "Qwen3_5ForCausalLM"),
- ("qwen3_moe", "Qwen3MoeForCausalLM"),
- ("qwen3_next", "Qwen3NextForCausalLM"),
- ("recurrent_gemma", "RecurrentGemmaForCausalLM"),
- ("reformer", "ReformerModelWithLMHead"),
- ("rembert", "RemBertForCausalLM"),
- ("roberta", "RobertaForCausalLM"),
- ("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"),
- ("roc_bert", "RoCBertForCausalLM"),
- ("roformer", "RoFormerForCausalLM"),
- ("rwkv", "RwkvForCausalLM"),
- ("seed_oss", "SeedOssForCausalLM"),
- ("smollm3", "SmolLM3ForCausalLM"),
- ("solar_open", "SolarOpenForCausalLM"),
- ("stablelm", "StableLmForCausalLM"),
- ("starcoder2", "Starcoder2ForCausalLM"),
- ("trocr", "TrOCRForCausalLM"),
- ("vaultgemma", "VaultGemmaForCausalLM"),
- ("whisper", "WhisperForCausalLM"),
- ("xglm", "XGLMForCausalLM"),
- ("xlm", "XLMWithLMHeadModel"),
- ("xlm-roberta", "XLMRobertaForCausalLM"),
- ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
- ("xlnet", "XLNetLMHeadModel"),
- ("xlstm", "xLSTMForCausalLM"),
- ("xmod", "XmodForCausalLM"),
- ("youtu", "YoutuForCausalLM"),
- ("zamba", "ZambaForCausalLM"),
- ("zamba2", "Zamba2ForCausalLM"),
- ]
- )
- MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
- [
- # Model for Image mapping
- ("aimv2_vision_model", "Aimv2VisionModel"),
- ("beit", "BeitModel"),
- ("bit", "BitModel"),
- ("cohere2_vision", "Cohere2VisionModel"),
- ("conditional_detr", "ConditionalDetrModel"),
- ("convnext", "ConvNextModel"),
- ("convnextv2", "ConvNextV2Model"),
- ("dab-detr", "DabDetrModel"),
- ("data2vec-vision", "Data2VecVisionModel"),
- ("deformable_detr", "DeformableDetrModel"),
- ("deit", "DeiTModel"),
- ("depth_pro", "DepthProModel"),
- ("detr", "DetrModel"),
- ("dinat", "DinatModel"),
- ("dinov2", "Dinov2Model"),
- ("dinov2_with_registers", "Dinov2WithRegistersModel"),
- ("dinov3_convnext", "DINOv3ConvNextModel"),
- ("dinov3_vit", "DINOv3ViTModel"),
- ("dpt", "DPTModel"),
- ("efficientnet", "EfficientNetModel"),
- ("focalnet", "FocalNetModel"),
- ("glpn", "GLPNModel"),
- ("hiera", "HieraModel"),
- ("ijepa", "IJepaModel"),
- ("imagegpt", "ImageGPTModel"),
- ("levit", "LevitModel"),
- ("llama4", "Llama4VisionModel"),
- ("mlcd", "MLCDVisionModel"), # Keep this to make some original hub repositories (from `DeepGlint-AI`) works
- ("mlcd_vision_model", "MLCDVisionModel"),
- ("mllama", "MllamaVisionModel"),
- ("mobilenet_v1", "MobileNetV1Model"),
- ("mobilenet_v2", "MobileNetV2Model"),
- ("mobilevit", "MobileViTModel"),
- ("mobilevitv2", "MobileViTV2Model"),
- ("pixio", "PixioModel"),
- ("poolformer", "PoolFormerModel"),
- ("pvt", "PvtModel"),
- ("regnet", "RegNetModel"),
- ("resnet", "ResNetModel"),
- ("segformer", "SegformerModel"),
- ("siglip_vision_model", "SiglipVisionModel"),
- ("swiftformer", "SwiftFormerModel"),
- ("swin", "SwinModel"),
- ("swin2sr", "Swin2SRModel"),
- ("swinv2", "Swinv2Model"),
- ("table-transformer", "TableTransformerModel"),
- ("timesformer", "TimesformerModel"),
- ("timm_backbone", "TimmBackbone"),
- ("timm_wrapper", "TimmWrapperModel"),
- ("videomae", "VideoMAEModel"),
- ("vit", "ViTModel"),
- ("vit_mae", "ViTMAEModel"),
- ("vit_msn", "ViTMSNModel"),
- ("vitdet", "VitDetModel"),
- ("vivit", "VivitModel"),
- ("yolos", "YolosModel"),
- ]
- )
- MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
- [
- ("deit", "DeiTForMaskedImageModeling"),
- ("focalnet", "FocalNetForMaskedImageModeling"),
- ("swin", "SwinForMaskedImageModeling"),
- ("swinv2", "Swinv2ForMaskedImageModeling"),
- ("vit", "ViTForMaskedImageModeling"),
- ]
- )
- MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
- # Model for Causal Image Modeling mapping
- [
- ("imagegpt", "ImageGPTForCausalImageModeling"),
- ]
- )
- MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Image Classification mapping
- ("beit", "BeitForImageClassification"),
- ("bit", "BitForImageClassification"),
- ("clip", "CLIPForImageClassification"),
- ("convnext", "ConvNextForImageClassification"),
- ("convnextv2", "ConvNextV2ForImageClassification"),
- ("cvt", "CvtForImageClassification"),
- ("data2vec-vision", "Data2VecVisionForImageClassification"),
- (
- "deit",
- ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher"),
- ),
- ("dinat", "DinatForImageClassification"),
- ("dinov2", "Dinov2ForImageClassification"),
- ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
- ("donut-swin", "DonutSwinForImageClassification"),
- ("efficientnet", "EfficientNetForImageClassification"),
- ("focalnet", "FocalNetForImageClassification"),
- ("hgnet_v2", "HGNetV2ForImageClassification"),
- ("hiera", "HieraForImageClassification"),
- ("ijepa", "IJepaForImageClassification"),
- ("imagegpt", "ImageGPTForImageClassification"),
- (
- "levit",
- ("LevitForImageClassification", "LevitForImageClassificationWithTeacher"),
- ),
- ("metaclip_2", "MetaClip2ForImageClassification"),
- ("mobilenet_v1", "MobileNetV1ForImageClassification"),
- ("mobilenet_v2", "MobileNetV2ForImageClassification"),
- ("mobilevit", "MobileViTForImageClassification"),
- ("mobilevitv2", "MobileViTV2ForImageClassification"),
- (
- "perceiver",
- (
- "PerceiverForImageClassificationLearned",
- "PerceiverForImageClassificationFourier",
- "PerceiverForImageClassificationConvProcessing",
- ),
- ),
- ("poolformer", "PoolFormerForImageClassification"),
- ("pp_lcnet", "PPLCNetForImageClassification"),
- ("pvt", "PvtForImageClassification"),
- ("pvt_v2", "PvtV2ForImageClassification"),
- ("regnet", "RegNetForImageClassification"),
- ("resnet", "ResNetForImageClassification"),
- ("segformer", "SegformerForImageClassification"),
- ("shieldgemma2", "ShieldGemma2ForImageClassification"),
- ("siglip", "SiglipForImageClassification"),
- ("siglip2", "Siglip2ForImageClassification"),
- ("swiftformer", "SwiftFormerForImageClassification"),
- ("swin", "SwinForImageClassification"),
- ("swinv2", "Swinv2ForImageClassification"),
- ("textnet", "TextNetForImageClassification"),
- ("timm_wrapper", "TimmWrapperForImageClassification"),
- ("vit", "ViTForImageClassification"),
- ("vit_msn", "ViTMSNForImageClassification"),
- ]
- )
- MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
- [
- # Do not add new models here, this class will be deprecated in the future.
- # Model for Image Segmentation mapping
- ("detr", "DetrForSegmentation"),
- ]
- )
- MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Semantic Segmentation mapping
- ("beit", "BeitForSemanticSegmentation"),
- ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
- ("dpt", "DPTForSemanticSegmentation"),
- ("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
- ("mobilevit", "MobileViTForSemanticSegmentation"),
- ("mobilevitv2", "MobileViTV2ForSemanticSegmentation"),
- ("segformer", "SegformerForSemanticSegmentation"),
- ("upernet", "UperNetForSemanticSegmentation"),
- ]
- )
- MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Instance Segmentation mapping
- # MaskFormerForInstanceSegmentation can be removed from this mapping in v5
- ("maskformer", "MaskFormerForInstanceSegmentation"),
- ]
- )
- MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Universal Segmentation mapping
- ("detr", "DetrForSegmentation"),
- ("eomt", "EomtForUniversalSegmentation"),
- ("eomt_dinov3", "EomtDinov3ForUniversalSegmentation"),
- ("mask2former", "Mask2FormerForUniversalSegmentation"),
- ("maskformer", "MaskFormerForInstanceSegmentation"),
- ("oneformer", "OneFormerForUniversalSegmentation"),
- ("videomt", "VideomtForUniversalSegmentation"),
- ]
- )
- MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
- [
- ("timesformer", "TimesformerForVideoClassification"),
- ("videomae", "VideoMAEForVideoClassification"),
- ("vivit", "VivitForVideoClassification"),
- ("vjepa2", "VJEPA2ForVideoClassification"),
- ]
- )
- MODEL_FOR_RETRIEVAL_MAPPING_NAMES = OrderedDict(
- [
- ("colmodernvbert", "ColModernVBertForRetrieval"),
- ("colpali", "ColPaliForRetrieval"),
- ]
- )
- MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
- [
- ("aria", "AriaForConditionalGeneration"),
- ("aya_vision", "AyaVisionForConditionalGeneration"),
- ("blip", "BlipForConditionalGeneration"),
- ("blip-2", "Blip2ForConditionalGeneration"),
- ("chameleon", "ChameleonForConditionalGeneration"),
- ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),
- ("deepseek_vl", "DeepseekVLForConditionalGeneration"),
- ("deepseek_vl_hybrid", "DeepseekVLHybridForConditionalGeneration"),
- ("emu3", "Emu3ForConditionalGeneration"),
- ("ernie4_5_vl_moe", "Ernie4_5_VLMoeForConditionalGeneration"),
- ("evolla", "EvollaForProteinText2Text"),
- ("fast_vlm", "FastVlmForConditionalGeneration"),
- ("florence2", "Florence2ForConditionalGeneration"),
- ("fuyu", "FuyuForCausalLM"),
- ("gemma3", "Gemma3ForConditionalGeneration"),
- ("gemma3n", "Gemma3nForConditionalGeneration"),
- ("gemma4", "Gemma4ForConditionalGeneration"),
- ("git", "GitForCausalLM"),
- ("glm46v", "Glm46VForConditionalGeneration"),
- ("glm4v", "Glm4vForConditionalGeneration"),
- ("glm4v_moe", "Glm4vMoeForConditionalGeneration"),
- ("glm_ocr", "GlmOcrForConditionalGeneration"),
- ("got_ocr2", "GotOcr2ForConditionalGeneration"),
- ("idefics", "IdeficsForVisionText2Text"),
- ("idefics2", "Idefics2ForConditionalGeneration"),
- ("idefics3", "Idefics3ForConditionalGeneration"),
- ("instructblip", "InstructBlipForConditionalGeneration"),
- ("instructblipvideo", "InstructBlipVideoForConditionalGeneration"),
- ("internvl", "InternVLForConditionalGeneration"),
- ("janus", "JanusForConditionalGeneration"),
- ("kosmos-2", "Kosmos2ForConditionalGeneration"),
- ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"),
- ("lfm2_vl", "Lfm2VlForConditionalGeneration"),
- ("lighton_ocr", "LightOnOcrForConditionalGeneration"),
- ("llama4", "Llama4ForConditionalGeneration"),
- ("llava", "LlavaForConditionalGeneration"),
- ("llava_next", "LlavaNextForConditionalGeneration"),
- ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
- ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
- ("mistral3", "Mistral3ForConditionalGeneration"),
- ("mistral4", "Mistral4ForCausalLM"),
- ("mllama", "MllamaForConditionalGeneration"),
- ("ovis2", "Ovis2ForConditionalGeneration"),
- ("paddleocr_vl", "PaddleOCRVLForConditionalGeneration"),
- ("paligemma", "PaliGemmaForConditionalGeneration"),
- ("perception_lm", "PerceptionLMForConditionalGeneration"),
- ("pi0", "PI0ForConditionalGeneration"),
- ("pix2struct", "Pix2StructForConditionalGeneration"),
- ("pixtral", "LlavaForConditionalGeneration"),
- ("pp_chart2table", "GotOcr2ForConditionalGeneration"),
- ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
- ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
- ("qwen3_5", "Qwen3_5ForConditionalGeneration"),
- ("qwen3_5_moe", "Qwen3_5MoeForConditionalGeneration"),
- ("qwen3_vl", "Qwen3VLForConditionalGeneration"),
- ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),
- ("shieldgemma2", "Gemma3ForConditionalGeneration"),
- ("smolvlm", "SmolVLMForConditionalGeneration"),
- ("t5gemma2", "T5Gemma2ForConditionalGeneration"),
- ("udop", "UdopForConditionalGeneration"),
- ("video_llama_3", "VideoLlama3ForConditionalGeneration"),
- ("video_llava", "VideoLlavaForConditionalGeneration"),
- ("vipllava", "VipLlavaForConditionalGeneration"),
- ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
- ]
- )
- # Models that accept text and optionally multimodal data in inputs
- # and can generate text and optionally multimodal data.
- MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES = OrderedDict(
- [
- *list(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.items()),
- ("glmasr", "GlmAsrForConditionalGeneration"),
- ("granite_speech", "GraniteSpeechForConditionalGeneration"),
- ("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
- ("phi4_multimodal", "Phi4MultimodalForCausalLM"),
- ("qwen2_5_omni", "Qwen2_5OmniForConditionalGeneration"),
- ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
- ("qwen3_omni_moe", "Qwen3OmniMoeForConditionalGeneration"),
- ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"),
- ("voxtral", "VoxtralForConditionalGeneration"),
- ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"),
- ]
- )
- MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
- [
- # Model for Masked LM mapping
- ("albert", "AlbertForMaskedLM"),
- ("bart", "BartForConditionalGeneration"),
- ("bert", "BertForMaskedLM"),
- ("big_bird", "BigBirdForMaskedLM"),
- ("camembert", "CamembertForMaskedLM"),
- ("convbert", "ConvBertForMaskedLM"),
- ("data2vec-text", "Data2VecTextForMaskedLM"),
- ("deberta", "DebertaForMaskedLM"),
- ("deberta-v2", "DebertaV2ForMaskedLM"),
- ("distilbert", "DistilBertForMaskedLM"),
- ("electra", "ElectraForMaskedLM"),
- ("ernie", "ErnieForMaskedLM"),
- ("esm", "EsmForMaskedLM"),
- ("eurobert", "EuroBertForMaskedLM"),
- ("flaubert", "FlaubertWithLMHeadModel"),
- ("fnet", "FNetForMaskedLM"),
- ("funnel", "FunnelForMaskedLM"),
- ("ibert", "IBertForMaskedLM"),
- ("jina_embeddings_v3", "JinaEmbeddingsV3ForMaskedLM"),
- ("layoutlm", "LayoutLMForMaskedLM"),
- ("longformer", "LongformerForMaskedLM"),
- ("luke", "LukeForMaskedLM"),
- ("mbart", "MBartForConditionalGeneration"),
- ("megatron-bert", "MegatronBertForMaskedLM"),
- ("mobilebert", "MobileBertForMaskedLM"),
- ("modernbert", "ModernBertForMaskedLM"),
- ("modernvbert", "ModernVBertForMaskedLM"),
- ("mpnet", "MPNetForMaskedLM"),
- ("mra", "MraForMaskedLM"),
- ("mvp", "MvpForConditionalGeneration"),
- ("nomic_bert", "NomicBertForMaskedLM"),
- ("nystromformer", "NystromformerForMaskedLM"),
- ("perceiver", "PerceiverForMaskedLM"),
- ("reformer", "ReformerForMaskedLM"),
- ("rembert", "RemBertForMaskedLM"),
- ("roberta", "RobertaForMaskedLM"),
- ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
- ("roc_bert", "RoCBertForMaskedLM"),
- ("roformer", "RoFormerForMaskedLM"),
- ("squeezebert", "SqueezeBertForMaskedLM"),
- ("tapas", "TapasForMaskedLM"),
- ("xlm", "XLMWithLMHeadModel"),
- ("xlm-roberta", "XLMRobertaForMaskedLM"),
- ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
- ("xmod", "XmodForMaskedLM"),
- ("yoso", "YosoForMaskedLM"),
- ]
- )
- MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Object Detection mapping
- ("conditional_detr", "ConditionalDetrForObjectDetection"),
- ("d_fine", "DFineForObjectDetection"),
- ("dab-detr", "DabDetrForObjectDetection"),
- ("deformable_detr", "DeformableDetrForObjectDetection"),
- ("detr", "DetrForObjectDetection"),
- ("lw_detr", "LwDetrForObjectDetection"),
- ("pp_doclayout_v2", "PPDocLayoutV2ForObjectDetection"),
- ("pp_doclayout_v3", "PPDocLayoutV3ForObjectDetection"),
- ("pp_ocrv5_mobile_det", "PPOCRV5MobileDetForObjectDetection"),
- ("pp_ocrv5_server_det", "PPOCRV5ServerDetForObjectDetection"),
- ("rt_detr", "RTDetrForObjectDetection"),
- ("rt_detr_v2", "RTDetrV2ForObjectDetection"),
- ("table-transformer", "TableTransformerForObjectDetection"),
- ("yolos", "YolosForObjectDetection"),
- ]
- )
- MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Zero Shot Object Detection mapping
- ("grounding-dino", "GroundingDinoForObjectDetection"),
- ("mm-grounding-dino", "MMGroundingDinoForObjectDetection"),
- ("omdet-turbo", "OmDetTurboForObjectDetection"),
- ("owlv2", "Owlv2ForObjectDetection"),
- ("owlvit", "OwlViTForObjectDetection"),
- ]
- )
- MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
- [
- # Model for depth estimation mapping
- ("chmv2", "CHMv2ForDepthEstimation"),
- ("depth_anything", "DepthAnythingForDepthEstimation"),
- ("depth_pro", "DepthProForDepthEstimation"),
- ("dpt", "DPTForDepthEstimation"),
- ("glpn", "GLPNForDepthEstimation"),
- ("prompt_depth_anything", "PromptDepthAnythingForDepthEstimation"),
- ("zoedepth", "ZoeDepthForDepthEstimation"),
- ]
- )
- MODEL_FOR_TEXT_RECOGNITION_MAPPING_NAMES = OrderedDict(
- [
- ("pp_ocrv5_mobile_rec", "PPOCRV5MobileRecForTextRecognition"),
- ("pp_ocrv5_server_rec", "PPOCRV5ServerRecForTextRecognition"),
- ]
- )
- MODEL_FOR_TABLE_RECOGNITION_MAPPING_NAMES = OrderedDict(
- [
- ("slanext", "SLANeXtForTableRecognition"),
- ]
- )
- MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
- [
- # Model for Seq2Seq Causal LM mapping
- ("audioflamingo3", "AudioFlamingo3ForConditionalGeneration"),
- ("bart", "BartForConditionalGeneration"),
- ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
- ("blenderbot", "BlenderbotForConditionalGeneration"),
- ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
- ("encoder-decoder", "EncoderDecoderModel"),
- ("fsmt", "FSMTForConditionalGeneration"),
- ("glmasr", "GlmAsrForConditionalGeneration"),
- ("granite_speech", "GraniteSpeechForConditionalGeneration"),
- ("led", "LEDForConditionalGeneration"),
- ("longt5", "LongT5ForConditionalGeneration"),
- ("m2m_100", "M2M100ForConditionalGeneration"),
- ("marian", "MarianMTModel"),
- ("mbart", "MBartForConditionalGeneration"),
- ("mt5", "MT5ForConditionalGeneration"),
- ("musicflamingo", "MusicFlamingoForConditionalGeneration"),
- ("mvp", "MvpForConditionalGeneration"),
- ("nllb-moe", "NllbMoeForConditionalGeneration"),
- ("pegasus", "PegasusForConditionalGeneration"),
- ("pegasus_x", "PegasusXForConditionalGeneration"),
- ("plbart", "PLBartForConditionalGeneration"),
- ("prophetnet", "ProphetNetForConditionalGeneration"),
- ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
- ("seamless_m4t", "SeamlessM4TForTextToText"),
- ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"),
- ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
- ("t5", "T5ForConditionalGeneration"),
- ("t5gemma", "T5GemmaForConditionalGeneration"),
- ("t5gemma2", "T5Gemma2ForConditionalGeneration"),
- ("umt5", "UMT5ForConditionalGeneration"),
- ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"),
- ("voxtral", "VoxtralForConditionalGeneration"),
- ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"),
- ]
- )
- MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
- [
- ("cohere_asr", "CohereAsrForConditionalGeneration"),
- ("dia", "DiaForConditionalGeneration"),
- ("granite_speech", "GraniteSpeechForConditionalGeneration"),
- ("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
- ("moonshine", "MoonshineForConditionalGeneration"),
- ("moonshine_streaming", "MoonshineStreamingForConditionalGeneration"),
- ("pop2piano", "Pop2PianoForConditionalGeneration"),
- ("seamless_m4t", "SeamlessM4TForSpeechToText"),
- ("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"),
- ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
- ("speech_to_text", "Speech2TextForConditionalGeneration"),
- ("speecht5", "SpeechT5ForSpeechToText"),
- ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"),
- ("voxtral", "VoxtralForConditionalGeneration"),
- ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"),
- ("whisper", "WhisperForConditionalGeneration"),
- ]
- )
- MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Sequence Classification mapping
- ("albert", "AlbertForSequenceClassification"),
- ("arcee", "ArceeForSequenceClassification"),
- ("bart", "BartForSequenceClassification"),
- ("bert", "BertForSequenceClassification"),
- ("big_bird", "BigBirdForSequenceClassification"),
- ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
- ("biogpt", "BioGptForSequenceClassification"),
- ("bloom", "BloomForSequenceClassification"),
- ("camembert", "CamembertForSequenceClassification"),
- ("canine", "CanineForSequenceClassification"),
- ("code_llama", "LlamaForSequenceClassification"),
- ("convbert", "ConvBertForSequenceClassification"),
- ("ctrl", "CTRLForSequenceClassification"),
- ("data2vec-text", "Data2VecTextForSequenceClassification"),
- ("deberta", "DebertaForSequenceClassification"),
- ("deberta-v2", "DebertaV2ForSequenceClassification"),
- ("deepseek_v2", "DeepseekV2ForSequenceClassification"),
- ("deepseek_v3", "DeepseekV3ForSequenceClassification"),
- ("diffllama", "DiffLlamaForSequenceClassification"),
- ("distilbert", "DistilBertForSequenceClassification"),
- ("doge", "DogeForSequenceClassification"),
- ("electra", "ElectraForSequenceClassification"),
- ("ernie", "ErnieForSequenceClassification"),
- ("esm", "EsmForSequenceClassification"),
- ("eurobert", "EuroBertForSequenceClassification"),
- ("exaone4", "Exaone4ForSequenceClassification"),
- ("falcon", "FalconForSequenceClassification"),
- ("flaubert", "FlaubertForSequenceClassification"),
- ("fnet", "FNetForSequenceClassification"),
- ("funnel", "FunnelForSequenceClassification"),
- ("gemma", "GemmaForSequenceClassification"),
- ("gemma2", "Gemma2ForSequenceClassification"),
- ("gemma3", "Gemma3ForSequenceClassification"),
- ("gemma3_text", "Gemma3TextForSequenceClassification"),
- ("glm", "GlmForSequenceClassification"),
- ("glm4", "Glm4ForSequenceClassification"),
- ("gpt-sw3", "GPT2ForSequenceClassification"),
- ("gpt2", "GPT2ForSequenceClassification"),
- ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
- ("gpt_neo", "GPTNeoForSequenceClassification"),
- ("gpt_neox", "GPTNeoXForSequenceClassification"),
- ("gpt_oss", "GptOssForSequenceClassification"),
- ("gptj", "GPTJForSequenceClassification"),
- ("helium", "HeliumForSequenceClassification"),
- ("hunyuan_v1_dense", "HunYuanDenseV1ForSequenceClassification"),
- ("hunyuan_v1_moe", "HunYuanMoEV1ForSequenceClassification"),
- ("ibert", "IBertForSequenceClassification"),
- ("jamba", "JambaForSequenceClassification"),
- ("jetmoe", "JetMoeForSequenceClassification"),
- ("jina_embeddings_v3", "JinaEmbeddingsV3ForSequenceClassification"),
- ("layoutlm", "LayoutLMForSequenceClassification"),
- ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
- ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
- ("lilt", "LiltForSequenceClassification"),
- ("llama", "LlamaForSequenceClassification"),
- ("longformer", "LongformerForSequenceClassification"),
- ("luke", "LukeForSequenceClassification"),
- ("markuplm", "MarkupLMForSequenceClassification"),
- ("mbart", "MBartForSequenceClassification"),
- ("megatron-bert", "MegatronBertForSequenceClassification"),
- ("minimax", "MiniMaxForSequenceClassification"),
- ("ministral", "MinistralForSequenceClassification"),
- ("ministral3", "Ministral3ForSequenceClassification"),
- ("mistral", "MistralForSequenceClassification"),
- ("mistral4", "Mistral4ForSequenceClassification"),
- ("mixtral", "MixtralForSequenceClassification"),
- ("mobilebert", "MobileBertForSequenceClassification"),
- ("modernbert", "ModernBertForSequenceClassification"),
- ("modernbert-decoder", "ModernBertDecoderForSequenceClassification"),
- ("modernvbert", "ModernVBertForSequenceClassification"),
- ("mpnet", "MPNetForSequenceClassification"),
- ("mpt", "MptForSequenceClassification"),
- ("mra", "MraForSequenceClassification"),
- ("mt5", "MT5ForSequenceClassification"),
- ("mvp", "MvpForSequenceClassification"),
- ("nemotron", "NemotronForSequenceClassification"),
- ("nomic_bert", "NomicBertForSequenceClassification"),
- ("nystromformer", "NystromformerForSequenceClassification"),
- ("openai-gpt", "OpenAIGPTForSequenceClassification"),
- ("opt", "OPTForSequenceClassification"),
- ("perceiver", "PerceiverForSequenceClassification"),
- ("persimmon", "PersimmonForSequenceClassification"),
- ("phi", "PhiForSequenceClassification"),
- ("phi3", "Phi3ForSequenceClassification"),
- ("phimoe", "PhimoeForSequenceClassification"),
- ("plbart", "PLBartForSequenceClassification"),
- ("qwen2", "Qwen2ForSequenceClassification"),
- ("qwen2_moe", "Qwen2MoeForSequenceClassification"),
- ("qwen3", "Qwen3ForSequenceClassification"),
- ("qwen3_5", "Qwen3_5ForSequenceClassification"),
- ("qwen3_5_text", "Qwen3_5ForSequenceClassification"),
- ("qwen3_moe", "Qwen3MoeForSequenceClassification"),
- ("qwen3_next", "Qwen3NextForSequenceClassification"),
- ("reformer", "ReformerForSequenceClassification"),
- ("rembert", "RemBertForSequenceClassification"),
- ("roberta", "RobertaForSequenceClassification"),
- ("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"),
- ("roc_bert", "RoCBertForSequenceClassification"),
- ("roformer", "RoFormerForSequenceClassification"),
- ("seed_oss", "SeedOssForSequenceClassification"),
- ("smollm3", "SmolLM3ForSequenceClassification"),
- ("squeezebert", "SqueezeBertForSequenceClassification"),
- ("stablelm", "StableLmForSequenceClassification"),
- ("starcoder2", "Starcoder2ForSequenceClassification"),
- ("t5", "T5ForSequenceClassification"),
- ("t5gemma", "T5GemmaForSequenceClassification"),
- ("t5gemma2", "T5Gemma2ForSequenceClassification"),
- ("tapas", "TapasForSequenceClassification"),
- ("umt5", "UMT5ForSequenceClassification"),
- ("xlm", "XLMForSequenceClassification"),
- ("xlm-roberta", "XLMRobertaForSequenceClassification"),
- ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
- ("xlnet", "XLNetForSequenceClassification"),
- ("xmod", "XmodForSequenceClassification"),
- ("yoso", "YosoForSequenceClassification"),
- ("zamba", "ZambaForSequenceClassification"),
- ("zamba2", "Zamba2ForSequenceClassification"),
- ]
- )
- MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
- [
- # Model for Question Answering mapping
- ("albert", "AlbertForQuestionAnswering"),
- ("arcee", "ArceeForQuestionAnswering"),
- ("bart", "BartForQuestionAnswering"),
- ("bert", "BertForQuestionAnswering"),
- ("big_bird", "BigBirdForQuestionAnswering"),
- ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
- ("bloom", "BloomForQuestionAnswering"),
- ("camembert", "CamembertForQuestionAnswering"),
- ("canine", "CanineForQuestionAnswering"),
- ("convbert", "ConvBertForQuestionAnswering"),
- ("data2vec-text", "Data2VecTextForQuestionAnswering"),
- ("deberta", "DebertaForQuestionAnswering"),
- ("deberta-v2", "DebertaV2ForQuestionAnswering"),
- ("diffllama", "DiffLlamaForQuestionAnswering"),
- ("distilbert", "DistilBertForQuestionAnswering"),
- ("electra", "ElectraForQuestionAnswering"),
- ("ernie", "ErnieForQuestionAnswering"),
- ("exaone4", "Exaone4ForQuestionAnswering"),
- ("falcon", "FalconForQuestionAnswering"),
- ("flaubert", "FlaubertForQuestionAnsweringSimple"),
- ("fnet", "FNetForQuestionAnswering"),
- ("funnel", "FunnelForQuestionAnswering"),
- ("gpt2", "GPT2ForQuestionAnswering"),
- ("gpt_neo", "GPTNeoForQuestionAnswering"),
- ("gpt_neox", "GPTNeoXForQuestionAnswering"),
- ("gptj", "GPTJForQuestionAnswering"),
- ("ibert", "IBertForQuestionAnswering"),
- ("jina_embeddings_v3", "JinaEmbeddingsV3ForQuestionAnswering"),
- ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
- ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
- ("led", "LEDForQuestionAnswering"),
- ("lilt", "LiltForQuestionAnswering"),
- ("llama", "LlamaForQuestionAnswering"),
- ("longformer", "LongformerForQuestionAnswering"),
- ("luke", "LukeForQuestionAnswering"),
- ("lxmert", "LxmertForQuestionAnswering"),
- ("markuplm", "MarkupLMForQuestionAnswering"),
- ("mbart", "MBartForQuestionAnswering"),
- ("megatron-bert", "MegatronBertForQuestionAnswering"),
- ("minimax", "MiniMaxForQuestionAnswering"),
- ("ministral", "MinistralForQuestionAnswering"),
- ("ministral3", "Ministral3ForQuestionAnswering"),
- ("mistral", "MistralForQuestionAnswering"),
- ("mixtral", "MixtralForQuestionAnswering"),
- ("mobilebert", "MobileBertForQuestionAnswering"),
- ("modernbert", "ModernBertForQuestionAnswering"),
- ("mpnet", "MPNetForQuestionAnswering"),
- ("mpt", "MptForQuestionAnswering"),
- ("mra", "MraForQuestionAnswering"),
- ("mt5", "MT5ForQuestionAnswering"),
- ("mvp", "MvpForQuestionAnswering"),
- ("nemotron", "NemotronForQuestionAnswering"),
- ("nystromformer", "NystromformerForQuestionAnswering"),
- ("opt", "OPTForQuestionAnswering"),
- ("qwen2", "Qwen2ForQuestionAnswering"),
- ("qwen2_moe", "Qwen2MoeForQuestionAnswering"),
- ("qwen3", "Qwen3ForQuestionAnswering"),
- ("qwen3_moe", "Qwen3MoeForQuestionAnswering"),
- ("qwen3_next", "Qwen3NextForQuestionAnswering"),
- ("reformer", "ReformerForQuestionAnswering"),
- ("rembert", "RemBertForQuestionAnswering"),
- ("roberta", "RobertaForQuestionAnswering"),
- ("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"),
- ("roc_bert", "RoCBertForQuestionAnswering"),
- ("roformer", "RoFormerForQuestionAnswering"),
- ("seed_oss", "SeedOssForQuestionAnswering"),
- ("smollm3", "SmolLM3ForQuestionAnswering"),
- ("splinter", "SplinterForQuestionAnswering"),
- ("squeezebert", "SqueezeBertForQuestionAnswering"),
- ("t5", "T5ForQuestionAnswering"),
- ("umt5", "UMT5ForQuestionAnswering"),
- ("xlm", "XLMForQuestionAnsweringSimple"),
- ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
- ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
- ("xlnet", "XLNetForQuestionAnsweringSimple"),
- ("xmod", "XmodForQuestionAnswering"),
- ("yoso", "YosoForQuestionAnswering"),
- ]
- )
# Table question answering: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TapasForQuestionAnswering"),
    ]
)
# Visual question answering: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
# Note that "blip-2" resolves to a conditional-generation class rather than a *ForQuestionAnswering class.
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("blip", "BlipForQuestionAnswering"),
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("vilt", "ViltForQuestionAnswering"),
    ]
)
# Document question answering: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("layoutlm", "LayoutLMForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
    ]
)
# Token classification: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
# Entries are kept in alphabetical key order; several keys may share one class (e.g. "gpt-sw3" and "gpt2").
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "AlbertForTokenClassification"),
        ("apertus", "ApertusForTokenClassification"),
        ("arcee", "ArceeForTokenClassification"),
        ("bert", "BertForTokenClassification"),
        ("big_bird", "BigBirdForTokenClassification"),
        ("biogpt", "BioGptForTokenClassification"),
        ("bloom", "BloomForTokenClassification"),
        ("bros", "BrosForTokenClassification"),
        ("camembert", "CamembertForTokenClassification"),
        ("canine", "CanineForTokenClassification"),
        ("convbert", "ConvBertForTokenClassification"),
        ("data2vec-text", "Data2VecTextForTokenClassification"),
        ("deberta", "DebertaForTokenClassification"),
        ("deberta-v2", "DebertaV2ForTokenClassification"),
        ("deepseek_v3", "DeepseekV3ForTokenClassification"),
        ("diffllama", "DiffLlamaForTokenClassification"),
        ("distilbert", "DistilBertForTokenClassification"),
        ("electra", "ElectraForTokenClassification"),
        ("ernie", "ErnieForTokenClassification"),
        ("esm", "EsmForTokenClassification"),
        ("eurobert", "EuroBertForTokenClassification"),
        ("exaone4", "Exaone4ForTokenClassification"),
        ("falcon", "FalconForTokenClassification"),
        ("flaubert", "FlaubertForTokenClassification"),
        ("fnet", "FNetForTokenClassification"),
        ("funnel", "FunnelForTokenClassification"),
        ("gemma", "GemmaForTokenClassification"),
        ("gemma2", "Gemma2ForTokenClassification"),
        ("glm", "GlmForTokenClassification"),
        ("glm4", "Glm4ForTokenClassification"),
        ("gpt-sw3", "GPT2ForTokenClassification"),
        ("gpt2", "GPT2ForTokenClassification"),
        ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
        ("gpt_neo", "GPTNeoForTokenClassification"),
        ("gpt_neox", "GPTNeoXForTokenClassification"),
        ("gpt_oss", "GptOssForTokenClassification"),
        ("helium", "HeliumForTokenClassification"),
        ("ibert", "IBertForTokenClassification"),
        ("jina_embeddings_v3", "JinaEmbeddingsV3ForTokenClassification"),
        ("layoutlm", "LayoutLMForTokenClassification"),
        ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
        ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
        ("lilt", "LiltForTokenClassification"),
        ("llama", "LlamaForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
        ("luke", "LukeForTokenClassification"),
        ("markuplm", "MarkupLMForTokenClassification"),
        ("megatron-bert", "MegatronBertForTokenClassification"),
        ("minimax", "MiniMaxForTokenClassification"),
        ("ministral", "MinistralForTokenClassification"),
        ("ministral3", "Ministral3ForTokenClassification"),
        ("mistral", "MistralForTokenClassification"),
        ("mistral4", "Mistral4ForTokenClassification"),
        ("mixtral", "MixtralForTokenClassification"),
        ("mobilebert", "MobileBertForTokenClassification"),
        ("modernbert", "ModernBertForTokenClassification"),
        ("modernvbert", "ModernVBertForTokenClassification"),
        ("mpnet", "MPNetForTokenClassification"),
        ("mpt", "MptForTokenClassification"),
        ("mra", "MraForTokenClassification"),
        ("mt5", "MT5ForTokenClassification"),
        ("nemotron", "NemotronForTokenClassification"),
        ("nomic_bert", "NomicBertForTokenClassification"),
        ("nystromformer", "NystromformerForTokenClassification"),
        ("persimmon", "PersimmonForTokenClassification"),
        ("phi", "PhiForTokenClassification"),
        ("phi3", "Phi3ForTokenClassification"),
        ("qwen2", "Qwen2ForTokenClassification"),
        ("qwen2_moe", "Qwen2MoeForTokenClassification"),
        ("qwen3", "Qwen3ForTokenClassification"),
        ("qwen3_moe", "Qwen3MoeForTokenClassification"),
        ("qwen3_next", "Qwen3NextForTokenClassification"),
        ("rembert", "RemBertForTokenClassification"),
        ("roberta", "RobertaForTokenClassification"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
        ("roc_bert", "RoCBertForTokenClassification"),
        ("roformer", "RoFormerForTokenClassification"),
        ("seed_oss", "SeedOssForTokenClassification"),
        ("smollm3", "SmolLM3ForTokenClassification"),
        ("squeezebert", "SqueezeBertForTokenClassification"),
        ("stablelm", "StableLmForTokenClassification"),
        ("starcoder2", "Starcoder2ForTokenClassification"),
        ("t5", "T5ForTokenClassification"),
        ("t5gemma", "T5GemmaForTokenClassification"),
        ("t5gemma2", "T5Gemma2ForTokenClassification"),
        ("umt5", "UMT5ForTokenClassification"),
        ("xlm", "XLMForTokenClassification"),
        ("xlm-roberta", "XLMRobertaForTokenClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
        ("xlnet", "XLNetForTokenClassification"),
        ("xmod", "XmodForTokenClassification"),
        ("yoso", "YosoForTokenClassification"),
    ]
)
# Multiple choice: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "AlbertForMultipleChoice"),
        ("bert", "BertForMultipleChoice"),
        ("big_bird", "BigBirdForMultipleChoice"),
        ("camembert", "CamembertForMultipleChoice"),
        ("canine", "CanineForMultipleChoice"),
        ("convbert", "ConvBertForMultipleChoice"),
        ("data2vec-text", "Data2VecTextForMultipleChoice"),
        ("deberta-v2", "DebertaV2ForMultipleChoice"),
        ("distilbert", "DistilBertForMultipleChoice"),
        ("electra", "ElectraForMultipleChoice"),
        ("ernie", "ErnieForMultipleChoice"),
        ("flaubert", "FlaubertForMultipleChoice"),
        ("fnet", "FNetForMultipleChoice"),
        ("funnel", "FunnelForMultipleChoice"),
        ("ibert", "IBertForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
        ("luke", "LukeForMultipleChoice"),
        ("megatron-bert", "MegatronBertForMultipleChoice"),
        ("mobilebert", "MobileBertForMultipleChoice"),
        ("modernbert", "ModernBertForMultipleChoice"),
        ("mpnet", "MPNetForMultipleChoice"),
        ("mra", "MraForMultipleChoice"),
        ("nystromformer", "NystromformerForMultipleChoice"),
        ("rembert", "RemBertForMultipleChoice"),
        ("roberta", "RobertaForMultipleChoice"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"),
        ("roc_bert", "RoCBertForMultipleChoice"),
        ("roformer", "RoFormerForMultipleChoice"),
        ("squeezebert", "SqueezeBertForMultipleChoice"),
        ("xlm", "XLMForMultipleChoice"),
        ("xlm-roberta", "XLMRobertaForMultipleChoice"),
        ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
        ("xlnet", "XLNetForMultipleChoice"),
        ("xmod", "XmodForMultipleChoice"),
        ("yoso", "YosoForMultipleChoice"),
    ]
)
# Next sentence prediction: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "BertForNextSentencePrediction"),
        ("ernie", "ErnieForNextSentencePrediction"),
        ("fnet", "FNetForNextSentencePrediction"),
        ("megatron-bert", "MegatronBertForNextSentencePrediction"),
        ("mobilebert", "MobileBertForNextSentencePrediction"),
    ]
)
# Audio (utterance-level) classification: model-type key -> class name
# (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping). Most entries point at *ForSequenceClassification classes.
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Classification mapping
        ("audio-spectrogram-transformer", "ASTForAudioClassification"),
        ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
        ("hubert", "HubertForSequenceClassification"),
        ("sew", "SEWForSequenceClassification"),
        ("sew-d", "SEWDForSequenceClassification"),
        ("unispeech", "UniSpeechForSequenceClassification"),
        ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
        ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
        ("wav2vec2-bert", "Wav2Vec2BertForSequenceClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
        ("wavlm", "WavLMForSequenceClassification"),
        ("whisper", "WhisperForAudioClassification"),
    ]
)
# CTC speech recognition: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
    [
        # Model for Connectionist temporal classification (CTC) mapping
        ("data2vec-audio", "Data2VecAudioForCTC"),
        ("hubert", "HubertForCTC"),
        ("lasr_ctc", "LasrForCTC"),
        ("parakeet_ctc", "ParakeetForCTC"),
        ("sew", "SEWForCTC"),
        ("sew-d", "SEWDForCTC"),
        ("unispeech", "UniSpeechForCTC"),
        ("unispeech-sat", "UniSpeechSatForCTC"),
        ("wav2vec2", "Wav2Vec2ForCTC"),
        ("wav2vec2-bert", "Wav2Vec2BertForCTC"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
        ("wavlm", "WavLMForCTC"),
    ]
)
# Audio frame (token) classification: model-type key -> class name
# (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping and used by AutoModelForAudioFrameClassification).
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Frame Classification mapping
        ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
        ("wav2vec2-bert", "Wav2Vec2BertForAudioFrameClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
        ("wavlm", "WavLMForAudioFrameClassification"),
    ]
)
# Audio retrieval via x-vector: model-type key -> class name
# (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping and used by AutoModelForAudioXVector).
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio XVector mapping
        ("data2vec-audio", "Data2VecAudioForXVector"),
        ("unispeech-sat", "UniSpeechSatForXVector"),
        ("wav2vec2", "Wav2Vec2ForXVector"),
        ("wav2vec2-bert", "Wav2Vec2BertForXVector"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
        ("wavlm", "WavLMForXVector"),
    ]
)
# Text-to-spectrogram: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text-To-Spectrogram mapping
        ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
        ("speecht5", "SpeechT5ForTextToSpeech"),
    ]
)
# Text-to-waveform: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
# Both fastspeech2 keys intentionally resolve to the same *WithHifiGan class.
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text-To-Waveform mapping
        ("bark", "BarkModel"),
        ("csm", "CsmForConditionalGeneration"),
        ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
        ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
        ("higgs_audio_v2", "HiggsAudioV2ForConditionalGeneration"),
        ("musicgen", "MusicgenForConditionalGeneration"),
        ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
        ("qwen2_5_omni", "Qwen2_5OmniForConditionalGeneration"),
        ("qwen3_omni_moe", "Qwen3OmniMoeForConditionalGeneration"),
        ("seamless_m4t", "SeamlessM4TForTextToSpeech"),
        ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"),
        ("vits", "VitsModel"),
    ]
)
# Zero-shot image classification: model-type key -> class name
# (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("align", "AlignModel"),
        ("altclip", "AltCLIPModel"),
        ("blip", "BlipModel"),
        ("blip-2", "Blip2ForImageTextRetrieval"),
        ("chinese_clip", "ChineseCLIPModel"),
        ("clip", "CLIPModel"),
        ("clipseg", "CLIPSegModel"),
        ("metaclip_2", "MetaClip2Model"),
        ("siglip", "SiglipModel"),
        ("siglip2", "Siglip2Model"),
    ]
)
# Backbones: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping used by AutoBackbone).
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
    [
        # Backbone mapping
        ("beit", "BeitBackbone"),
        ("bit", "BitBackbone"),
        ("convnext", "ConvNextBackbone"),
        ("convnextv2", "ConvNextV2Backbone"),
        ("dinat", "DinatBackbone"),
        ("dinov2", "Dinov2Backbone"),
        ("dinov2_with_registers", "Dinov2WithRegistersBackbone"),
        ("dinov3_convnext", "DINOv3ConvNextBackbone"),
        ("dinov3_vit", "DINOv3ViTBackbone"),
        ("focalnet", "FocalNetBackbone"),
        ("hgnet_v2", "HGNetV2Backbone"),
        ("hiera", "HieraBackbone"),
        ("lw_detr_vit", "LwDetrViTBackbone"),
        ("maskformer-swin", "MaskFormerSwinBackbone"),
        ("pixio", "PixioBackbone"),
        ("pp_lcnet", "PPLCNetBackbone"),
        ("pp_lcnet_v3", "PPLCNetV3Backbone"),
        ("pvt_v2", "PvtV2Backbone"),
        ("resnet", "ResNetBackbone"),
        ("rt_detr_resnet", "RTDetrResNetBackbone"),
        ("swin", "SwinBackbone"),
        ("swinv2", "Swinv2Backbone"),
        ("textnet", "TextNetBackbone"),
        ("timm_backbone", "TimmBackbone"),
        ("uvdoc_backbone", "UVDocBackbone"),
        ("vitdet", "VitDetBackbone"),
        ("vitpose_backbone", "VitPoseBackbone"),
    ]
)
# Mask generation: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
# The *_video keys deliberately reuse the image-level class of their family.
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
    [
        ("edgetam", "EdgeTamModel"),
        ("edgetam_video", "EdgeTamModel"),
        ("sam", "SamModel"),
        ("sam2", "Sam2Model"),
        ("sam2_video", "Sam2Model"),
        ("sam3_tracker", "Sam3TrackerModel"),
        ("sam3_video", "Sam3TrackerModel"),
        ("sam_hq", "SamHQModel"),
    ]
)
# Keypoint detection: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        ("superpoint", "SuperPointForKeypointDetection"),
    ]
)
# Keypoint matching: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES = OrderedDict(
    [
        ("efficientloftr", "EfficientLoFTRForKeypointMatching"),
        ("lightglue", "LightGlueForKeypointMatching"),
        ("superglue", "SuperGlueForKeypointMatching"),
    ]
)
# Text encoding: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
# Encoder-decoder families point at their *EncoderModel; multimodal families point at their *TextModel.
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "AlbertModel"),
        ("bert", "BertModel"),
        ("big_bird", "BigBirdModel"),
        ("clip_text_model", "CLIPTextModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("distilbert", "DistilBertModel"),
        ("electra", "ElectraModel"),
        ("emu3", "Emu3TextModel"),
        ("flaubert", "FlaubertModel"),
        ("ibert", "IBertModel"),
        ("llama4", "Llama4TextModel"),
        ("longformer", "LongformerModel"),
        ("mllama", "MllamaTextModel"),
        ("mobilebert", "MobileBertModel"),
        ("mt5", "MT5EncoderModel"),
        ("nystromformer", "NystromformerModel"),
        ("reformer", "ReformerModel"),
        ("rembert", "RemBertModel"),
        ("roberta", "RobertaModel"),
        ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
        ("roc_bert", "RoCBertModel"),
        ("roformer", "RoFormerModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("t5", "T5EncoderModel"),
        ("t5gemma", "T5GemmaEncoderModel"),
        ("umt5", "UMT5EncoderModel"),
        ("xlm", "XLMModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
    ]
)
# Time-series classification: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("patchtsmixer", "PatchTSMixerForTimeSeriesClassification"),
        ("patchtst", "PatchTSTForClassification"),
    ]
)
# Time-series regression: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict(
    [
        ("patchtsmixer", "PatchTSMixerForRegression"),
        ("patchtst", "PatchTSTForRegression"),
    ]
)
# Time-series prediction: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("timesfm", "TimesFmModelForPrediction"),
        ("timesfm2_5", "TimesFm2_5ModelForPrediction"),
    ]
)
# Image-to-image: model-type key -> class name (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
    [
        ("swin2sr", "Swin2SRForImageSuperResolution"),
    ]
)
# Audio tokenization (codebook models): model-type key -> class name
# (paired with CONFIG_MAPPING_NAMES in a _LazyAutoMapping).
# NOTE: unlike its siblings this constant lacks the "_MAPPING" infix; the name is kept as-is for existing importers.
MODEL_FOR_AUDIO_TOKENIZATION_NAMES = OrderedDict(
    [
        ("dac", "DacModel"),
        ("higgs_audio_v2_tokenizer", "HiggsAudioV2TokenizerModel"),
        ("vibevoice_acoustic_tokenizer", "VibeVoiceAcousticTokenizerModel"),
    ]
)
# Lazy task mappings: each pairs CONFIG_MAPPING_NAMES with one of the *_NAMES tables above via _LazyAutoMapping.
# These are the objects the auto classes below reference through `_model_mapping`.

# --- Base / language modeling ---
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)

# --- Vision ---
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
)

# --- Multimodal ---
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
)
MODEL_FOR_MULTIMODAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES)
MODEL_FOR_RETRIEVAL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_RETRIEVAL_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)

# --- Masked modeling and dense vision heads ---
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_MAPPING_NAMES)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
MODEL_FOR_TEXT_RECOGNITION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_RECOGNITION_MAPPING_NAMES)
MODEL_FOR_TABLE_RECOGNITION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_RECOGNITION_MAPPING_NAMES)

# --- Text heads ---
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)

# --- Audio / speech ---
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES
)
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES)

# --- Backbones, keypoints, encoders ---
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_KEYPOINT_MATCHING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES)
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)

# --- Time series ---
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES
)
MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES
)

# --- Misc ---
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
# The audio-tokenization table is named without the usual "_MAPPING" infix (MODEL_FOR_AUDIO_TOKENIZATION_NAMES).
MODEL_FOR_AUDIO_TOKENIZATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_TOKENIZATION_NAMES)
# Auxiliary auto classes. Unlike most auto classes in this section, none of these five is passed through
# `auto_class_update` here, so they keep whatever class-level documentation `_BaseAutoModelClass` provides.


class AutoModelForMaskGeneration(_BaseAutoModelClass):
    # Lazy mapping backing this auto class.
    _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING


class AutoModelForKeypointDetection(_BaseAutoModelClass):
    # Lazy mapping backing this auto class.
    _model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING


class AutoModelForKeypointMatching(_BaseAutoModelClass):
    # Lazy mapping backing this auto class.
    _model_mapping = MODEL_FOR_KEYPOINT_MATCHING_MAPPING


class AutoModelForTextEncoding(_BaseAutoModelClass):
    # Lazy mapping backing this auto class.
    _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING


class AutoModelForImageToImage(_BaseAutoModelClass):
    # Lazy mapping backing this auto class.
    _model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING
# Generic auto class: resolves through the base (no task head) mapping.
class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING


# No head_doc is passed for the generic class.
AutoModel = auto_class_update(AutoModel)


# Auto class for pretraining heads.
class AutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


AutoModelForPreTraining = auto_class_update(
    AutoModelForPreTraining,
    head_doc="pretraining",
)
# Auto class for causal (left-to-right) language modeling heads.
class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING

    # Overridden only to narrow the return type hint for static type checkers; every argument is forwarded
    # unchanged to the base implementation.
    @classmethod
    def from_pretrained(
        cls: type["AutoModelForCausalLM"],
        pretrained_model_name_or_path: str | os.PathLike[str],
        *model_args,
        **kwargs,
    ) -> "_BaseModelWithGenerate":
        loaded = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        return loaded


AutoModelForCausalLM = auto_class_update(
    AutoModelForCausalLM,
    head_doc="causal language modeling",
)
# Auto class for masked language modeling heads.
class AutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


AutoModelForMaskedLM = auto_class_update(
    AutoModelForMaskedLM,
    head_doc="masked language modeling",
)


# Auto class for sequence-to-sequence language modeling heads; documented with a T5 example checkpoint.
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM,
    checkpoint_for_example="google-t5/t5-base",
    head_doc="sequence-to-sequence language modeling",
)


# Auto class for sequence classification heads.
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification,
    head_doc="sequence classification",
)
# Auto class for extractive question answering heads.
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


AutoModelForQuestionAnswering = auto_class_update(
    AutoModelForQuestionAnswering,
    head_doc="question answering",
)


# Auto class for table question answering; documented with a TAPAS example checkpoint.
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
    head_doc="table question answering",
)


# Auto class for visual question answering; documented with a ViLT example checkpoint.
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


AutoModelForVisualQuestionAnswering = auto_class_update(
    AutoModelForVisualQuestionAnswering,
    checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
    head_doc="visual question answering",
)


# Auto class for document question answering.
class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


AutoModelForDocumentQuestionAnswering = auto_class_update(
    AutoModelForDocumentQuestionAnswering,
    # NOTE(review): the embedded double quotes look deliberate. They appear to splice a `revision=` argument
    # into a double-quoted checkpoint placeholder of the generated docstring. Confirm against
    # auto_class_update's template before "fixing" the quoting.
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
    head_doc="document question answering",
)
# Auto class for token classification heads.
class AutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


AutoModelForTokenClassification = auto_class_update(
    AutoModelForTokenClassification,
    head_doc="token classification",
)


# Auto class for multiple-choice heads.
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


AutoModelForMultipleChoice = auto_class_update(
    AutoModelForMultipleChoice,
    head_doc="multiple choice",
)


# Auto class for next-sentence-prediction heads.
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction,
    head_doc="next sentence prediction",
)
# Auto class for image classification heads.
class AutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


AutoModelForImageClassification = auto_class_update(
    AutoModelForImageClassification,
    head_doc="image classification",
)


# Auto class for zero-shot image classification.
class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


AutoModelForZeroShotImageClassification = auto_class_update(
    AutoModelForZeroShotImageClassification,
    head_doc="zero-shot image classification",
)


# Auto class for image segmentation heads.
class AutoModelForImageSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING


AutoModelForImageSegmentation = auto_class_update(
    AutoModelForImageSegmentation,
    head_doc="image segmentation",
)


# Auto class for semantic segmentation heads.
class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


AutoModelForSemanticSegmentation = auto_class_update(
    AutoModelForSemanticSegmentation,
    head_doc="semantic segmentation",
)
# Auto class for time-series prediction models.
class AutoModelForTimeSeriesPrediction(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING


AutoModelForTimeSeriesPrediction = auto_class_update(
    AutoModelForTimeSeriesPrediction,
    head_doc="time-series prediction",
)


# Auto class for universal image segmentation heads.
class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING


AutoModelForUniversalSegmentation = auto_class_update(
    AutoModelForUniversalSegmentation,
    head_doc="universal image segmentation",
)


# Auto class for instance segmentation heads.
class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING


AutoModelForInstanceSegmentation = auto_class_update(
    AutoModelForInstanceSegmentation,
    head_doc="instance segmentation",
)
# Auto class for object detection heads.
class AutoModelForObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


AutoModelForObjectDetection = auto_class_update(
    AutoModelForObjectDetection,
    head_doc="object detection",
)


# Auto class for zero-shot object detection.
class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING


AutoModelForZeroShotObjectDetection = auto_class_update(
    AutoModelForZeroShotObjectDetection,
    head_doc="zero-shot object detection",
)


# Auto class for depth estimation heads.
class AutoModelForDepthEstimation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING


AutoModelForDepthEstimation = auto_class_update(
    AutoModelForDepthEstimation,
    head_doc="depth estimation",
)


# Auto class for text recognition heads.
class AutoModelForTextRecognition(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TEXT_RECOGNITION_MAPPING


AutoModelForTextRecognition = auto_class_update(
    AutoModelForTextRecognition,
    head_doc="text recognition",
)


# Auto class for table recognition heads.
class AutoModelForTableRecognition(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TABLE_RECOGNITION_MAPPING


AutoModelForTableRecognition = auto_class_update(
    AutoModelForTableRecognition,
    head_doc="table recognition",
)


# Auto class for video classification heads.
class AutoModelForVideoClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING


AutoModelForVideoClassification = auto_class_update(
    AutoModelForVideoClassification,
    head_doc="video classification",
)
# Auto class for image-text-to-text (vision-language generation) models.
class AutoModelForImageTextToText(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING

    # Overridden only to narrow the return type hint for static type checkers; every argument is forwarded
    # unchanged to the base implementation.
    @classmethod
    def from_pretrained(
        cls: type["AutoModelForImageTextToText"],
        pretrained_model_name_or_path: str | os.PathLike[str],
        *model_args,
        **kwargs,
    ) -> "_BaseModelWithGenerate":
        loaded = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        return loaded


AutoModelForImageTextToText = auto_class_update(
    AutoModelForImageTextToText,
    head_doc="image-text-to-text modeling",
)


# Auto class for multimodal generation models.
class AutoModelForMultimodalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MULTIMODAL_LM_MAPPING


AutoModelForMultimodalLM = auto_class_update(
    AutoModelForMultimodalLM,
    head_doc="multimodal generation",
)
# Auto class for utterance-level audio classification heads.
class AutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


AutoModelForAudioClassification = auto_class_update(
    AutoModelForAudioClassification,
    head_doc="audio classification",
)


# Auto class for CTC speech-recognition heads.
class AutoModelForCTC(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CTC_MAPPING


AutoModelForCTC = auto_class_update(
    AutoModelForCTC,
    head_doc="connectionist temporal classification",
)


# Auto class for sequence-to-sequence speech-to-text models.
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq,
    head_doc="sequence-to-sequence speech-to-text modeling",
)


# Auto class for per-frame audio classification heads.
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING


AutoModelForAudioFrameClassification = auto_class_update(
    AutoModelForAudioFrameClassification,
    head_doc="audio frame (token) classification",
)
# Auto class for x-vector (speaker embedding / retrieval) heads.
class AutoModelForAudioXVector(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


# Decorated right after its definition (the original did this after the next three classes); the
# classes below do not reference it, so only the textual position changes.
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")


# The following three auto classes are not passed through `auto_class_update` here.
class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING


class AutoModelForTextToWaveform(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING


# Backbones use a dedicated base class (_BaseAutoBackboneClass) instead of _BaseAutoModelClass.
class AutoBackbone(_BaseAutoBackboneClass):
    _model_mapping = MODEL_FOR_BACKBONE_MAPPING
# Auto class for masked image modeling heads.
class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


AutoModelForMaskedImageModeling = auto_class_update(
    AutoModelForMaskedImageModeling,
    head_doc="masked image modeling",
)


# Auto class for codebook-based audio tokenizer models.
class AutoModelForAudioTokenization(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_TOKENIZATION_MAPPING


AutoModelForAudioTokenization = auto_class_update(
    AutoModelForAudioTokenization,
    head_doc="audio tokenization through codebooks",
)
# Public names exported by this module: exactly the names `from <module> import *`
# binds. The MODEL_FOR_*_MAPPING lookup tables come first, followed by the Auto*
# classes. Not every mapping has a matching Auto* class in this list (e.g.
# MODEL_FOR_IMAGE_MAPPING, MODEL_FOR_RETRIEVAL_MAPPING,
# MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING).
# NOTE(review): keep this a plain list literal with one quoted name per line and no
# inline comments. Tooling may parse `__all__` textually instead of importing the
# module (transformers' lazy-import helpers appear to), so confirm before restructuring.
__all__ = [
    "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
    "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING",
    "MODEL_FOR_AUDIO_TOKENIZATION_MAPPING",
    "MODEL_FOR_AUDIO_XVECTOR_MAPPING",
    "MODEL_FOR_BACKBONE_MAPPING",
    "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
    "MODEL_FOR_CAUSAL_LM_MAPPING",
    "MODEL_FOR_CTC_MAPPING",
    "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
    "MODEL_FOR_DEPTH_ESTIMATION_MAPPING",
    "MODEL_FOR_TEXT_RECOGNITION_MAPPING",
    "MODEL_FOR_TABLE_RECOGNITION_MAPPING",
    "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
    "MODEL_FOR_IMAGE_MAPPING",
    "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
    "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
    "MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
    "MODEL_FOR_KEYPOINT_MATCHING_MAPPING",
    "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
    "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
    "MODEL_FOR_MASKED_LM_MAPPING",
    "MODEL_FOR_MASK_GENERATION_MAPPING",
    "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
    "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
    "MODEL_FOR_OBJECT_DETECTION_MAPPING",
    "MODEL_FOR_PRETRAINING_MAPPING",
    "MODEL_FOR_QUESTION_ANSWERING_MAPPING",
    "MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
    "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
    "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
    "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
    "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
    "MODEL_FOR_TEXT_ENCODING_MAPPING",
    "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
    "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
    "MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING",
    "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
    "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
    "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
    "MODEL_FOR_RETRIEVAL_MAPPING",
    "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING",
    "MODEL_FOR_MULTIMODAL_LM_MAPPING",
    "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
    "MODEL_MAPPING",
    "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
    "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
    "MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING",
    "MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING",
    "AutoModel",
    "AutoBackbone",
    "AutoModelForAudioClassification",
    "AutoModelForAudioFrameClassification",
    "AutoModelForAudioTokenization",
    "AutoModelForAudioXVector",
    "AutoModelForCausalLM",
    "AutoModelForCTC",
    "AutoModelForDepthEstimation",
    "AutoModelForTextRecognition",
    "AutoModelForTableRecognition",
    "AutoModelForImageClassification",
    "AutoModelForImageSegmentation",
    "AutoModelForImageToImage",
    "AutoModelForInstanceSegmentation",
    "AutoModelForKeypointDetection",
    "AutoModelForKeypointMatching",
    "AutoModelForMaskGeneration",
    "AutoModelForTextEncoding",
    "AutoModelForMaskedImageModeling",
    "AutoModelForMaskedLM",
    "AutoModelForMultipleChoice",
    "AutoModelForMultimodalLM",
    "AutoModelForNextSentencePrediction",
    "AutoModelForObjectDetection",
    "AutoModelForPreTraining",
    "AutoModelForQuestionAnswering",
    "AutoModelForSemanticSegmentation",
    "AutoModelForSeq2SeqLM",
    "AutoModelForSequenceClassification",
    "AutoModelForSpeechSeq2Seq",
    "AutoModelForTableQuestionAnswering",
    "AutoModelForTextToSpectrogram",
    "AutoModelForTextToWaveform",
    "AutoModelForTimeSeriesPrediction",
    "AutoModelForTokenClassification",
    "AutoModelForUniversalSegmentation",
    "AutoModelForVideoClassification",
    "AutoModelForVisualQuestionAnswering",
    "AutoModelForDocumentQuestionAnswering",
    "AutoModelForZeroShotImageClassification",
    "AutoModelForZeroShotObjectDetection",
    "AutoModelForImageTextToText",
]
|