auto_docstring.py 161 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
7427842794280428142824283428442854286428742884289429042914292429342944295429642974298429943004301430243034304430543064307430843094310431143124313431443154316431743184319432043214322432343244325432643274328432943304331433243334334433543364337433843394340434143424343434443454346434743484349435043514352435343544355435643574358435943604361436243634364436543664367436843694370437143724373437443754376437743784379438043814382438343844385438643874388438943904391439243934394439543964397439843994400440144024403440444054406440744084409441044114412441344144415441644174418441944204421442244234424442544264427442844294430443144324433443444354436443744384439444044414442444344444445444644474448444944504451445244534454445544564457445844594460446144624463446444654466446744684469447044714472447344744475447644774478447944804481448244834484448544864487448844894490449144924493449444954496449744984499450045014502450345044505450645074508450945104511451245134514
  1. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. import inspect
  16. import os
  17. from collections.abc import Mapping
  18. from functools import lru_cache
  19. from pathlib import Path
  20. from types import UnionType
  21. from typing import ClassVar, Union, get_args, get_origin
  22. import regex as re
  23. import typing_extensions
  24. from .doc import (
  25. MODELS_TO_PIPELINE,
  26. PIPELINE_TASKS_TO_SAMPLE_DOCSTRINGS,
  27. PT_SAMPLE_DOCSTRINGS,
  28. )
  29. from .generic import ModelOutput
  30. PATH_TO_TRANSFORMERS = Path("src").resolve() / "transformers"
  31. AUTODOC_FILES = [
  32. "configuration_*.py",
  33. "modeling_*.py",
  34. "tokenization_*.py",
  35. "processing_*.py",
  36. "image_processing_pil_*.py",
  37. "image_processing_*.py",
  38. "feature_extractor_*.py",
  39. ]
  40. PLACEHOLDER_TO_AUTO_MODULE = {
  41. "image_processor_class": ("image_processing_auto", "IMAGE_PROCESSOR_MAPPING_NAMES"),
  42. "tokenizer_class": ("tokenization_auto", "TOKENIZER_MAPPING_NAMES"),
  43. "video_processor_class": ("video_processing_auto", "VIDEO_PROCESSOR_MAPPING_NAMES"),
  44. "feature_extractor_class": ("feature_extraction_auto", "FEATURE_EXTRACTOR_MAPPING_NAMES"),
  45. "processor_class": ("processing_auto", "PROCESSOR_MAPPING_NAMES"),
  46. "config_class": ("configuration_auto", "CONFIG_MAPPING_NAMES"),
  47. "model_class": ("modeling_auto", "MODEL_MAPPING_NAMES"),
  48. }
  49. UNROLL_KWARGS_METHODS = {
  50. "preprocess",
  51. "__call__",
  52. }
  53. UNROLL_KWARGS_CLASSES = {
  54. "BaseImageProcessor",
  55. "ProcessorMixin",
  56. }
  57. BASIC_KWARGS_TYPES = ["TextKwargs", "ImagesKwargs", "VideosKwargs", "AudioKwargs"]
  58. # Short indicator added to unrolled kwargs to distinguish them from regular args
  59. KWARGS_INDICATOR = ", *kwargs*"
  60. HARDCODED_CONFIG_FOR_MODELS = {
  61. "openai": "OpenAIGPTConfig",
  62. "x-clip": "XCLIPConfig",
  63. "kosmos2": "Kosmos2Config",
  64. "kosmos2-5": "Kosmos2_5Config",
  65. "donut": "DonutSwinConfig",
  66. "esmfold": "EsmConfig",
  67. "parakeet": "ParakeetCTCConfig",
  68. "lasr": "LasrCTCConfig",
  69. "wav2vec2-with-lm": "Wav2Vec2Config",
  70. }
  71. _re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
  72. # Pre-compiled patterns used repeatedly at runtime. Compiling once here avoids
  73. # repeated compilation overhead (and cache lookups) on every decorator call.
  74. _re_example_or_return = re.compile(r"(?m)^([ \t]*)(?=Example|Return|```)")
  75. _re_return = re.compile(r"(?m)^([ \t]*)(?=Return)")
  76. _re_example = re.compile(r"(?m)^([ \t]*)(?=Example|```)")
  77. _re_args_section = re.compile(r"(?:Args:)(\n.*)?(\n)?$", re.DOTALL)
  78. _re_shape = re.compile(r"(of shape\s*(?:`.*?`|\(.*?\)))")
  79. _re_default = re.compile(r"(defaults to \s*[^)]*)")
  80. _re_param = re.compile(
  81. r"^\s{0,0}(\w+)\s*\(\s*([^, \)]*)(\s*.*?)\s*\)\s*:\s*((?:(?!\n^\s{0,0}\w+\s*\().)*)",
  82. re.DOTALL | re.MULTILINE,
  83. )
  84. _re_forward_ref = re.compile(r"ForwardRef\('([\w.]+)'\)")
  85. _re_optional = re.compile(r"Optional\[(.*?)\]")
  86. _re_placeholders = re.compile(r"{(.*?)}")
  87. _re_model_task = None # built lazily because PT_SAMPLE_DOCSTRINGS isn't available yet
  88. class ImageProcessorArgs:
  89. images = {
  90. "description": """
  91. Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
  92. passing in images with pixel values between 0 and 1, set `do_rescale=False`.
  93. """,
  94. "shape": None,
  95. }
  96. videos = {
  97. "description": """
  98. Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
  99. passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
  100. """,
  101. "shape": None,
  102. }
  103. do_resize = {
  104. "description": """
  105. Whether to resize the image.
  106. """,
  107. "shape": None,
  108. }
  109. size = {
  110. "description": """
  111. Describes the maximum input dimensions to the model.
  112. """,
  113. "shape": None,
  114. }
  115. size_divisor = {
  116. "description": """
  117. The size by which to make sure both the height and width can be divided.
  118. """,
  119. "shape": None,
  120. }
  121. default_to_square = {
  122. "description": """
  123. Whether to default to a square image when resizing, if size is an int.
  124. """,
  125. "shape": None,
  126. }
  127. resample = {
  128. "description": """
  129. Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
  130. has an effect if `do_resize` is set to `True`.
  131. """,
  132. "shape": None,
  133. }
  134. do_center_crop = {
  135. "description": """
  136. Whether to center crop the image.
  137. """,
  138. "shape": None,
  139. }
  140. crop_size = {
  141. "description": """
  142. Size of the output image after applying `center_crop`.
  143. """,
  144. "shape": None,
  145. }
  146. do_pad = {
  147. "description": """
  148. Whether to pad the image. Padding is done either to the largest size in the batch
  149. or to a fixed square size per image. The exact padding strategy depends on the model.
  150. """,
  151. "shape": None,
  152. }
  153. pad_size = {
  154. "description": """
  155. The size in `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
  156. provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
  157. height and width in the batch. Applied only when `do_pad=True.`
  158. """,
  159. "shape": None,
  160. }
  161. do_rescale = {
  162. "description": """
  163. Whether to rescale the image.
  164. """,
  165. "shape": None,
  166. }
  167. rescale_factor = {
  168. "description": """
  169. Rescale factor to rescale the image by if `do_rescale` is set to `True`.
  170. """,
  171. "shape": None,
  172. }
  173. do_normalize = {
  174. "description": """
  175. Whether to normalize the image.
  176. """,
  177. "shape": None,
  178. }
  179. image_mean = {
  180. "description": """
  181. Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
  182. """,
  183. "shape": None,
  184. }
  185. image_std = {
  186. "description": """
  187. Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
  188. `True`.
  189. """,
  190. "shape": None,
  191. }
  192. do_convert_rgb = {
  193. "description": """
  194. Whether to convert the image to RGB.
  195. """,
  196. "shape": None,
  197. }
  198. return_tensors = {
  199. "description": """
  200. Returns stacked tensors if set to `'pt'`, otherwise returns a list of tensors.
  201. """,
  202. "shape": None,
  203. }
  204. data_format = {
  205. "description": """
  206. Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
  207. """,
  208. "shape": None,
  209. }
  210. input_data_format = {
  211. "description": """
  212. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  213. from the input image. Can be one of:
  214. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  215. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  216. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
  217. """,
  218. "shape": None,
  219. }
  220. device = {
  221. "description": """
  222. The device to process the images on. If unset, the device is inferred from the input images.
  223. """,
  224. "shape": None,
  225. }
  226. disable_grouping = {
  227. "description": """
  228. Whether to disable grouping of images by size to process them individually and not in batches.
  229. If None, will be set to True if the images are on CPU, and False otherwise. This choice is based on
  230. empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157
  231. """,
  232. "shape": None,
  233. }
  234. image_seq_length = {
  235. "description": """
  236. The number of image tokens to be used for each image in the input.
  237. Added for backward compatibility but this should be set as a processor attribute in future models.
  238. """,
  239. "shape": None,
  240. }
  241. # Used for the **kwargs summary line when unrolling typed kwargs (key: "__kwargs__")
  242. __kwargs__ = {
  243. "description": """
  244. Additional image preprocessing options. Model-specific kwargs are listed above; see the TypedDict class
  245. for the complete list of supported arguments.
  246. """,
  247. "shape": None,
  248. }
  249. class ProcessorArgs:
  250. # __init__ arguments
  251. image_processor = {
  252. "description": """
  253. The image processor is a required input.
  254. """,
  255. "type": "{image_processor_class}",
  256. }
  257. tokenizer = {
  258. "description": """
  259. The tokenizer is a required input.
  260. """,
  261. "type": "{tokenizer_class}",
  262. }
  263. video_processor = {
  264. "description": """
  265. The video processor is a required input.
  266. """,
  267. "type": "{video_processor_class}",
  268. }
  269. audio_processor = {
  270. "description": """
  271. The audio processor is a required input.
  272. """,
  273. "type": "{audio_processor_class}",
  274. }
  275. feature_extractor = {
  276. "description": """
  277. The feature extractor is a required input.
  278. """,
  279. "type": "{feature_extractor_class}",
  280. }
  281. chat_template = {
  282. "description": """
  283. A Jinja template to convert lists of messages in a chat into a tokenizable string.
  284. """,
  285. "type": "str",
  286. }
  287. # __call__ arguments
  288. text = {
  289. "description": """
  290. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  291. (pretokenized string). If you pass a pretokenized input, set `is_split_into_words=True` to avoid ambiguity with batched inputs.
  292. """,
  293. }
  294. audio = {
  295. "description": """
  296. The audio or batch of audios to be prepared. Each audio can be a NumPy array or PyTorch tensor.
  297. In case of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels,
  298. and T is the sample length of the audio.
  299. """,
  300. }
  301. audios = {
  302. "description": """
  303. The audio or batch of audios to be prepared. Each audio can be a NumPy array or PyTorch tensor.
  304. In case of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels,
  305. and T is the sample length of the audio.
  306. """,
  307. }
  308. return_tensors = {
  309. "description": """
  310. If set, will return tensors of a particular framework. Acceptable values are:
  311. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  312. - `'np'`: Return NumPy `np.ndarray` objects.
  313. """,
  314. "shape": None,
  315. }
  316. # Standard tokenizer arguments
  317. add_special_tokens = {
  318. "description": """
  319. Whether or not to add special tokens when encoding the sequences. This will use the underlying
  320. [`PretrainedTokenizerBase.build_inputs_with_special_tokens`] function, which defines which tokens are
  321. automatically added to the input ids. This is useful if you want to add `bos` or `eos` tokens
  322. automatically.
  323. """,
  324. "type": "bool",
  325. }
  326. padding = {
  327. "description": """
  328. Activates and controls padding. Accepts the following values:
  329. - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
  330. sequence is provided).
  331. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  332. acceptable input length for the model if that argument is not provided.
  333. - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
  334. lengths).
  335. """,
  336. "type": "bool, str or [`~utils.PaddingStrategy`]",
  337. }
  338. truncation = {
  339. "description": """
  340. Activates and controls truncation. Accepts the following values:
  341. - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
  342. to the maximum acceptable input length for the model if that argument is not provided. This will
  343. truncate token by token, removing a token from the longest sequence in the pair if a pair of
  344. sequences (or a batch of pairs) is provided.
  345. - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
  346. maximum acceptable input length for the model if that argument is not provided. This will only
  347. truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
  348. - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
  349. maximum acceptable input length for the model if that argument is not provided. This will only
  350. truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
  351. - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
  352. greater than the model maximum admissible input size).
  353. """,
  354. "type": "bool, str or [`~tokenization_utils_base.TruncationStrategy`]",
  355. }
  356. max_length = {
  357. "description": """
  358. Controls the maximum length to use by one of the truncation/padding parameters.
  359. If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
  360. is required by one of the truncation/padding parameters. If the model has no specific maximum input
  361. length (like XLNet) truncation/padding to a maximum length will be deactivated.
  362. """,
  363. "type": "int",
  364. }
  365. stride = {
  366. "description": """
  367. If set to a number along with `max_length`, the overflowing tokens returned when
  368. `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
  369. returned to provide some overlap between truncated and overflowing sequences. The value of this
  370. argument defines the number of overlapping tokens.
  371. """,
  372. "type": "int",
  373. }
  374. pad_to_multiple_of = {
  375. "description": """
  376. If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated.
  377. This is especially useful to enable using Tensor Cores on NVIDIA hardware with compute capability
  378. `>= 7.5` (Volta).
  379. """,
  380. "type": "int",
  381. }
  382. return_token_type_ids = {
  383. "description": """
  384. Whether to return token type IDs. If left to the default, will return the token type IDs according to
  385. the specific tokenizer's default, defined by the `return_outputs` attribute.
  386. [What are token type IDs?](../glossary#token-type-ids)
  387. """,
  388. "type": "bool",
  389. }
  390. return_attention_mask = {
  391. "description": """
  392. Whether to return the attention mask. If left to the default, will return the attention mask according
  393. to the specific tokenizer's default, defined by the `return_outputs` attribute.
  394. [What are attention masks?](../glossary#attention-mask)
  395. """,
  396. "type": "bool",
  397. }
  398. return_overflowing_tokens = {
  399. "description": """
  400. Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
  401. of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead
  402. of returning overflowing tokens.
  403. """,
  404. "type": "bool",
  405. }
  406. return_special_tokens_mask = {
  407. "description": """
  408. Whether or not to return special tokens mask information.
  409. """,
  410. "type": "bool",
  411. }
  412. return_offsets_mapping = {
  413. "description": """
  414. Whether or not to return `(char_start, char_end)` for each token.
  415. This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using
  416. Python's tokenizer, this method will raise `NotImplementedError`.
  417. """,
  418. "type": "bool",
  419. }
  420. return_length = {
  421. "description": """
  422. Whether or not to return the lengths of the encoded inputs.
  423. """,
  424. "type": "bool",
  425. }
  426. verbose = {
  427. "description": """
  428. Whether or not to print more information and warnings.
  429. """,
  430. "type": "bool",
  431. }
  432. text_pair = {
  433. "description": """
  434. Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using
  435. the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
  436. method).
  437. """,
  438. "type": "str, list[str] or list[int]",
  439. }
  440. text_target = {
  441. "description": """
  442. The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
  443. list of strings (pretokenized string). If you pass pretokenized input, set `is_split_into_words=True`
  444. to avoid ambiguity with batched inputs.
  445. """,
  446. "type": "str, list[str] or list[list[str]]",
  447. }
  448. text_pair_target = {
  449. "description": """
  450. The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
  451. list of strings (pretokenized string). If you pass pretokenized input, set `is_split_into_words=True`
  452. to avoid ambiguity with batched inputs.
  453. """,
  454. "type": "str, list[str] or list[list[str]]",
  455. }
  456. is_split_into_words = {
  457. "description": """
  458. Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
  459. tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
  460. which it will tokenize. This is useful for NER or token classification.
  461. """,
  462. "type": "bool",
  463. }
  464. boxes = {
  465. "description": """
  466. Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
  467. """,
  468. "type": "list[list[int]] or list[list[list[int]]]",
  469. }
  470. word_labels = {
  471. "description": """
  472. Word-level integer labels (for token classification tasks such as FUNSD, CORD).
  473. """,
  474. "type": "list[int] or list[list[int]]",
  475. }
  476. # Used for the **kwargs summary line when unrolling typed kwargs (key: "__kwargs__")
  477. __kwargs__ = {
  478. "description": """
  479. Additional processing options for each modality (text, images, videos, audio). Model-specific parameters
  480. are listed above; see the TypedDict class for the complete list of supported arguments.
  481. """,
  482. "shape": None,
  483. }
  484. class ConfigArgs:
  485. output_hidden_states = {
  486. "description": """
  487. Whether or not the model should return all hidden-states.
  488. """,
  489. }
  490. chunk_size_feed_forward = {
  491. "description": """
  492. The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
  493. (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
  494. model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
  495. `float16` weights.
  496. """,
  497. }
  498. dtype = {
  499. "description": """
  500. The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
  501. the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
  502. sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
  503. Forward Chunking work?](../glossary.html#feed-forward-chunking).
  504. """,
  505. }
  506. id2label = {
  507. "description": """
  508. A map from index (for instance prediction index, or target index) to label.
  509. """,
  510. }
  511. label2id = {
  512. "description": """
  513. A map from label to index for the model.
  514. """,
  515. }
  516. problem_type = {
  517. "description": """
  518. Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
  519. `"single_label_classification"` or `"multi_label_classification"`.
  520. """,
  521. }
  522. tokenizer_class = {
  523. "description": """
  524. The class name of model's tokenizer.
  525. """,
  526. }
  527. vocab_size = {
  528. "description": """
  529. Vocabulary size of the model. Defines the number of different tokens that can be represented by the `input_ids`.
  530. """,
  531. }
  532. hidden_size = {
  533. "description": """
  534. Dimension of the hidden representations.
  535. """,
  536. }
  537. intermediate_size = {
  538. "description": """
  539. Dimension of the MLP representations.
  540. """,
  541. }
  542. head_dim = {
  543. "description": """
  544. The attention head dimension. If None, it will default to hidden_size // num_attention_heads
  545. """
  546. }
  547. num_hidden_layers = {
  548. "description": """
  549. Number of hidden layers in the Transformer decoder.
  550. """,
  551. }
  552. num_attention_heads = {
  553. "description": """
  554. Number of attention heads for each attention layer in the Transformer decoder.
  555. """,
  556. }
  557. num_key_value_heads = {
  558. "description": """
  559. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  560. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  561. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  562. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  563. by meanpooling all the original heads within that group. For more details, check out [this
  564. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  565. `num_attention_heads`.
  566. """,
  567. }
  568. hidden_act = {
  569. "description": """
  570. The non-linear activation function (function or string) in the decoder. For example, `"gelu"`,
  571. `"relu"`, `"silu"`, etc.
  572. """,
  573. }
  574. max_position_embeddings = {
  575. "description": """
  576. The maximum sequence length that this model might ever be used with.
  577. """,
  578. }
  579. initializer_range = {
  580. "description": """
  581. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  582. """,
  583. }
  584. rms_norm_eps = {
  585. "description": """
  586. The epsilon used by the rms normalization layers.
  587. """,
  588. }
  589. use_cache = {
  590. "description": """
  591. Whether or not the model should return the last key/values attentions (not used by all models). Only
  592. relevant if `config.is_decoder=True` or when the model is a decoder-only generative model.
  593. """,
  594. }
  595. rope_parameters = {
  596. "description": """
  597. Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
  598. a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
  599. with longer `max_position_embeddings`.
  600. """,
  601. }
  602. attention_bias = {
  603. "description": """
  604. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  605. """,
  606. }
  607. mlp_bias = {
  608. "description": """
  609. Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
  610. """,
  611. }
  612. attention_dropout = {
  613. "description": """
  614. The dropout ratio for the attention probabilities.
  615. """,
  616. }
  617. pretraining_tp = {
  618. "description": """
  619. Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
  620. document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
  621. understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
  622. results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
  623. """,
  624. }
  625. pad_token_id = {
  626. "description": """
  627. Token id used for padding in the vocabulary.
  628. """,
  629. }
  630. eos_token_id = {
  631. "description": """
  632. Token id used for end-of-stream in the vocabulary.
  633. """,
  634. }
  635. bos_token_id = {
  636. "description": """
  637. Token id used for beginning-of-stream in the vocabulary.
  638. """,
  639. }
  640. sep_token_id = {
  641. "description": """
  642. Token id used for separator in the vocabulary.
  643. """,
  644. }
  645. cls_token_id = {
  646. "description": """
  647. Token id used for CLS in the vocabulary.
  648. """,
  649. }
  650. tie_word_embeddings = {
  651. "description": """
  652. Whether to tie weight embeddings according to model's `tied_weights_keys` mapping.
  653. """,
  654. }
  655. d_model = {
  656. "description": """
  657. Size of the encoder layers and the pooler layer.
  658. """,
  659. }
  660. d_kv = {
  661. "description": """
  662. Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
  663. be defined as `num_heads * d_kv`.
  664. """,
  665. }
  666. num_decoder_layers = {
  667. "description": """
  668. Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
  669. """,
  670. }
  671. num_encoder_layers = {
  672. "description": """
  673. Number of hidden layers in the Transformer encoder. Will use the same value as `num_layers` if not set.
  674. """,
  675. }
  676. dropout_rate = {
  677. "description": """
  678. The ratio for all dropout layers.
  679. """,
  680. }
  681. classifier_dropout = {
  682. "description": """
  683. The dropout ratio for classifier.
  684. """,
  685. }
  686. layer_norm_eps = {
  687. "description": """
  688. The epsilon used by the layer normalization layers.
  689. """,
  690. }
  691. initializer_factor = {
  692. "description": """
  693. A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
  694. testing).
  695. """,
  696. }
  697. encoder_attention_heads = {
  698. "description": """
  699. Number of attention heads for each attention layer in the Transformer encoder.
  700. """,
  701. }
  702. decoder_attention_heads = {
  703. "description": """
  704. Number of attention heads for each attention layer in the Transformer decoder.
  705. """,
  706. }
  707. decoder_ffn_dim = {
  708. "description": """
  709. Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
  710. """,
  711. }
  712. encoder_ffn_dim = {
  713. "description": """
  714. Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
  715. """,
  716. }
  717. activation_dropout = {
  718. "description": """
  719. The dropout ratio for activations inside the fully connected layer.
  720. """,
  721. }
  722. encoder_layerdrop = {
  723. "description": """
  724. The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://huggingface.co/papers/1909.11556)
  725. for more details.
  726. """,
  727. }
  728. decoder_layerdrop = {
  729. "description": """
  730. The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://huggingface.co/papers/1909.11556)
  731. for more details.
  732. """,
  733. }
  734. scale_embedding = {
  735. "description": """
  736. Whether to scale embeddings by dividing by sqrt(d_model).
  737. """,
  738. }
  739. forced_eos_token_id = {
  740. "description": """
  741. The id of the token to force as the last generated token when `max_length` is reached. Usually set to
  742. `eos_token_id`.
  743. """,
  744. }
  745. moe_intermediate_size = {
  746. "description": """
  747. Intermediate size of the routed expert MLPs.
  748. """,
  749. }
  750. num_experts = {
  751. "description": """
  752. Number of routed experts in MoE layers.
  753. """,
  754. }
  755. num_experts_per_tok = {
  756. "description": """
  757. Number of experts to route each token to. This is the top-k value for the token-choice routing.
  758. """,
  759. }
  760. num_shared_experts = {
  761. "description": """
  762. Number of shared experts that are always activated for all tokens.
  763. """,
  764. }
  765. layer_types = {
  766. "description": """
  767. A list that explicitly maps each layer index with its layer type. If not provided, it will be automatically
  768. generated based on config values.
  769. """,
  770. }
  771. norm_topk_prob = {
  772. "description": """
  773. Whether to normalize the weights of the routed experts.
  774. """,
  775. }
  776. topk_group = {
  777. "description": """
  778. Number of selected groups for each token (for each token, ensuring the selected experts is only within `topk_group` groups).
  779. """,
  780. }
  781. qk_rope_head_dim = {
  782. "description": """
  783. Dimension of the query/key heads that use rotary position embeddings.
  784. """,
  785. }
  786. v_head_dim = {
  787. "description": """
  788. Dimension of the value heads.
  789. """,
  790. }
  791. qk_nope_head_dim = {
  792. "description": """
  793. Dimension of the query/key heads that don't use rotary position embeddings.
  794. """,
  795. }
  796. kv_lora_rank = {
  797. "description": """
  798. Rank of the LoRA matrices for key and value projections.
  799. """,
  800. }
  801. q_lora_rank = {
  802. "description": """
  803. Rank of the LoRA matrices for query projections.
  804. """,
  805. }
  806. routed_scaling_factor = {
  807. "description": """
  808. Scaling factor or routed experts.
  809. """,
  810. }
  811. n_routed_experts = {
  812. "description": """
  813. Number of routed experts.
  814. """,
  815. }
  816. n_shared_experts = {
  817. "description": """
  818. Number of shared experts.
  819. """,
  820. }
  821. vision_config = {
  822. "description": """
  823. The config object or dictionary of the vision backbone.
  824. """,
  825. }
  826. text_config = {
  827. "description": """
  828. The config object or dictionary of the text backbone.
  829. """,
  830. }
  831. projector_hidden_act = {
  832. "description": """
  833. The activation function used by the multimodal projector.
  834. """,
  835. }
  836. vision_feature_select_strategy = {
  837. "description": """
  838. The feature selection strategy used to select the vision feature from the vision backbone.
  839. """,
  840. }
  841. vision_feature_layer = {
  842. "description": """
  843. The index of the layer to select the vision feature. If multiple indices are provided,
  844. the vision feature of the corresponding indices will be concatenated to form the
  845. vision features.
  846. """,
  847. }
  848. multimodal_projector_bias = {
  849. "description": """
  850. Whether to use bias in the multimodal projector.
  851. """,
  852. }
  853. image_token_id = {
  854. "description": """
  855. The image token index used as a placeholder for input images.
  856. """,
  857. }
  858. video_token_id = {
  859. "description": """
  860. The video token index used as a placeholder for input videos.
  861. """,
  862. }
  863. audio_token_id = {
  864. "description": """
  865. The audio token index used as a placeholder for input audio.
  866. """,
  867. }
  868. image_seq_length = {
  869. "description": """
  870. Sequence length of one image embedding.
  871. """,
  872. }
  873. video_seq_length = {
  874. "description": """
  875. Sequence length of one video embedding.
  876. """,
  877. }
  878. add_cross_attention = {
  879. "description": """
  880. Whether cross-attention layers should be added to the model.
  881. """,
  882. }
  883. is_decoder = {
  884. "description": """
  885. Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
  886. """,
  887. }
  888. sliding_window = {
  889. "description": """
  890. Sliding window attention window size. If `None`, no sliding window is applied.
  891. """,
  892. }
  893. use_sliding_window = {
  894. "description": """
  895. Whether to use sliding window attention.
  896. """,
  897. }
  898. shared_expert_intermediate_size = {
  899. "description": """
  900. Intermediate size of the shared expert MLPs.
  901. """,
  902. }
  903. decoder_sparse_step = {
  904. "description": """
  905. The frequency of adding a sparse MoE layer. The default is 1, which means all decoder layers are sparse MoE.
  906. """,
  907. }
  908. output_router_logits = {
  909. "description": """
  910. Whether or not the router logits should be returned by the model. Enabling this will also allow the model
  911. to output the auxiliary loss, including load balancing loss and router z-loss.
  912. """,
  913. }
  914. router_aux_loss_coef = {
  915. "description": """
  916. Auxiliary load balancing loss coefficient. Used to penalize uneven expert routing in MoE models.
  917. """,
  918. }
  919. out_indices = {
  920. "description": """
  921. Indices of the intermediate hidden states (feature maps) to return from the backbone. Each index
  922. corresponds to one stage of the model.
  923. """,
  924. }
  925. out_features = {
  926. "description": """
  927. Names of the intermediate hidden states (feature maps) to return from the backbone. One of `"stem"`,
  928. `"stage1"`, `"stage2"`, etc.
  929. """,
  930. }
  931. image_size = {
  932. "description": """
  933. The size (resolution) of each image.
  934. """,
  935. }
  936. patch_size = {
  937. "description": """
  938. The size (resolution) of each patch.
  939. """,
  940. }
  941. num_channels = {
  942. "description": """
  943. The number of input channels.
  944. """,
  945. }
  946. num_mel_bins = {
  947. "description": """
  948. Number of mel features used per input frame. Should correspond to the value used in the
  949. `AutoFeatureExtractor` class.
  950. """,
  951. }
  952. sampling_rate = {
  953. "description": """
  954. The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
  955. """,
  956. }
  957. hidden_dropout = {
  958. "description": """
  959. The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
  960. """,
  961. }
  962. mlp_ratio = {
  963. "description": """
  964. Ratio of the MLP hidden dim to the embedding dim.
  965. """,
  966. }
  967. qkv_bias = {
  968. "description": """
  969. Whether to add a bias to the queries, keys and values.
  970. """,
  971. }
  972. n_embd = {
  973. "description": """
  974. Dimensionality of the embeddings and hidden states.
  975. """,
  976. }
  977. resid_pdrop = {
  978. "description": """
  979. The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
  980. """,
  981. }
  982. embd_pdrop = {
  983. "description": """
  984. The dropout ratio for the embeddings.
  985. """,
  986. }
  987. clip_qkv = {
  988. "description": """
  989. If not `None`, cap the absolute value of the query, key, and value tensors to this value.
  990. """,
  991. }
  992. type_vocab_size = {
  993. "description": """
  994. The vocabulary size of the `token_type_ids`.
  995. """,
  996. }
  997. audio_config = {
  998. "description": """
  999. The config object or dictionary of the audio backbone.
  1000. """,
  1001. }
  1002. layerdrop = {
  1003. "description": """
  1004. The LayerDrop probability. See the [LayerDrop paper](see https://huggingface.co/papers/1909.11556) for
  1005. more details.
  1006. """,
  1007. }
  1008. expert_capacity = {
  1009. "description": """
  1010. The number of tokens that each expert can process. If `None`, `expert_capacity` will be set to
  1011. `(sequence_length / num_experts) * capacity_factor`.
  1012. """,
  1013. }
  1014. decoder_start_token_id = {
  1015. "description": """
  1016. If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
  1017. """,
  1018. }
  1019. is_encoder_decoder = {
  1020. "description": """
  1021. Whether the model is used as an encoder/decoder or not.
  1022. """,
  1023. }
  1024. num_codebooks = {
  1025. "description": """
  1026. The number of parallel codebooks used by the model.
  1027. """,
  1028. }
  1029. codebook_dim = {
  1030. "description": """
  1031. Dimensionality of each codebook embedding vector.
  1032. """,
  1033. }
  1034. hidden_sizes = {
  1035. "description": """
  1036. Dimensionality (hidden size) at each stage of the model.
  1037. """,
  1038. }
  1039. depths = {
  1040. "description": """
  1041. Depth of each layer in the Transformer.
  1042. """,
  1043. }
  1044. patch_sizes = {
  1045. "description": """
  1046. Patch size at each stage of the model.
  1047. """,
  1048. }
  1049. strides = {
  1050. "description": """
  1051. Stride at each stage of the model.
  1052. """,
  1053. }
  1054. router_jitter_noise = {
  1055. "description": """
  1056. Amount of noise to add to the router logits during training for better load balancing.
  1057. """,
  1058. }
  1059. num_local_experts = {
  1060. "description": """
  1061. Number of local experts on each device. `num_experts` should be divisible by `num_local_experts`.
  1062. """,
  1063. }
  1064. qk_layernorm = {
  1065. "description": """
  1066. Whether to use query-key normalization in the attention.
  1067. """,
  1068. }
  1069. backbone_config = {
  1070. "description": """
  1071. The configuration of the backbone model.
  1072. """,
  1073. }
  1074. no_object_weight = {
  1075. "description": """
  1076. Relative classification weight of the no-object class in the object detection loss.
  1077. """,
  1078. }
  1079. class_weight = {
  1080. "description": """
  1081. Relative weight of the classification error in the Hungarian matching cost.
  1082. """,
  1083. }
  1084. mask_weight = {
  1085. "description": """
  1086. Relative weight of the focal loss in the panoptic segmentation loss.
  1087. """,
  1088. }
  1089. dice_weight = {
  1090. "description": """
  1091. Relative weight of the dice loss in the panoptic segmentation loss.
  1092. """,
  1093. }
  1094. class_cost = {
  1095. "description": """
  1096. Relative weight of the classification error in the Hungarian matching cost.
  1097. """,
  1098. }
  1099. bbox_cost = {
  1100. "description": """
  1101. Relative weight of the L1 bounding box error in the Hungarian matching cost.
  1102. """,
  1103. }
  1104. giou_cost = {
  1105. "description": """
  1106. Relative weight of the generalized IoU loss in the Hungarian matching cost.
  1107. """,
  1108. }
  1109. focal_alpha = {
  1110. "description": """
  1111. Alpha parameter in the focal loss.
  1112. """,
  1113. }
  1114. mask_loss_coefficient = {
  1115. "description": """
  1116. Relative weight of the focal loss in the panoptic segmentation loss.
  1117. """,
  1118. }
  1119. giou_loss_coefficient = {
  1120. "description": """
  1121. Relative weight of the generalized IoU loss in the panoptic segmentation loss.
  1122. """,
  1123. }
  1124. bbox_loss_coefficient = {
  1125. "description": """
  1126. Relative weight of the L1 bounding box loss in the panoptic segmentation loss.
  1127. """,
  1128. }
  1129. cls_loss_coefficient = {
  1130. "description": """
  1131. Relative weight of the classification loss in the panoptic segmentation loss.
  1132. """,
  1133. }
  1134. dice_loss_coefficient = {
  1135. "description": """
  1136. Relative weight of the dice loss in the panoptic segmentation loss.
  1137. """,
  1138. }
  1139. semantic_loss_ignore_index = {
  1140. "description": """
  1141. The index that is ignored by the loss function of the semantic segmentation model.
  1142. """,
  1143. }
  1144. projection_dim = {
  1145. "description": """
  1146. Dimensionality of text and vision projection layers.
  1147. """,
  1148. }
  1149. logit_scale_init_value = {
  1150. "description": """
  1151. The initial value of the *logit_scale* parameter.
  1152. """,
  1153. }
  1154. num_dense_layers = {
  1155. "description": """
  1156. Number of initial dense layers before MoE layers begin. Layers with index < num_dense_layers will use
  1157. standard dense MLPs instead of MoE.
  1158. """,
  1159. }
  1160. drop_path_rate = {
  1161. "description": """
  1162. Drop path rate for the patch fusion.
  1163. """,
  1164. }
  1165. vq_config = {
  1166. "description": """
  1167. Configuration dict of the vector quantize module.
  1168. """,
  1169. }
  1170. num_embeddings = {
  1171. "description": """
  1172. Number of codebook embeddings.
  1173. """,
  1174. }
  1175. double_latent = {
  1176. "description": """
  1177. Whether to use double z channels.
  1178. """,
  1179. }
  1180. latent_channels = {
  1181. "description": """
  1182. Number of channels for the latent space.
  1183. """,
  1184. }
  1185. qformer_config = {
  1186. "description": """
  1187. Configuration dict of the Q-Former module.
  1188. """,
  1189. }
  1190. conv_kernel_size = {
  1191. "description": """
  1192. The size of the convolutional kernel.
  1193. """,
  1194. }
  1195. output_stride = {
  1196. "description": """
  1197. The ratio between the spatial resolution of the input and output feature maps.
  1198. """,
  1199. }
  1200. depth_multiplier = {
  1201. "description": """
  1202. Shrinks or expands the number of channels in each layer. This is sometimes also called "alpha" or "width multiplier".
  1203. """,
  1204. }
  1205. use_absolute_position_embeddings = {
  1206. "description": """
  1207. Whether to use absolute position embeddings.
  1208. """,
  1209. }
  1210. use_relative_position_bias = {
  1211. "description": """
  1212. Whether to use relative position bias in the self-attention layers.
  1213. """,
  1214. }
  1215. layer_scale_init_value = {
  1216. "description": """
  1217. Scale to use in the self-attention layers. 0.1 for base, 1e-6 for large. Set 0 to disable layer scale.
  1218. """,
  1219. }
  1220. vlm_config = {
  1221. "description": """
  1222. The config object or dictionary of the vision-language backbone.
  1223. """,
  1224. }
  1225. init_xavier_std = {
  1226. "description": """
  1227. The scaling factor used for the Xavier initialization of the cross-attention weights.
  1228. """,
  1229. }
  1230. auxiliary_loss = {
  1231. "description": """
  1232. Whether auxiliary decoding losses (losses at each decoder layer) are to be used.
  1233. """,
  1234. }
  1235. encoder_config = {
  1236. "description": """
  1237. The config object or dictionary of the encoder backbone.
  1238. """,
  1239. }
  1240. decoder_config = {
  1241. "description": """
  1242. The config object or dictionary of the decoder backbone.
  1243. """,
  1244. }
  1245. embedding_multiplier = {
  1246. "description": """
  1247. Scaling factor applied to the word embeddings. Used to scale the embeddings relative to the hidden size.
  1248. """,
  1249. }
  1250. logits_scaling = {
  1251. "description": """
  1252. Scaling factor applied to the output logits before computing the probability distribution.
  1253. """,
  1254. }
  1255. residual_multiplier = {
  1256. "description": """
  1257. Scaling factor applied to the residual connections.
  1258. """,
  1259. }
  1260. attention_multiplier = {
  1261. "description": """
  1262. Scaling factor applied to the attention weights.
  1263. """,
  1264. }
  1265. classifier_activation = {
  1266. "description": """
  1267. The activation function for the classification head.
  1268. """,
  1269. }
  1270. return_dict = {
  1271. "description": """
  1272. Whether to return a `ModelOutput` (dataclass) instead of a plain tuple.
  1273. """,
  1274. }
  1275. router_z_loss_coef = {
  1276. "description": """
  1277. Coefficient for the router z-loss, which penalizes large router logits to improve training stability.
  1278. """,
  1279. }
  1280. final_logit_softcapping = {
  1281. "description": """
  1282. Soft-capping value applied to the final logits before computing the probability distribution. Logits are
  1283. scaled by `tanh(logit / cap) * cap`.
  1284. """,
  1285. }
  1286. cross_attention_hidden_size = {
  1287. "description": """
  1288. Hidden size of the encoder outputs projected into the cross-attention key/value space of the decoder. Used
  1289. when the encoder and decoder have different hidden sizes.
  1290. """,
  1291. }
  1292. input_dim = {
  1293. "description": """
  1294. Dimensionality of the input acoustic features (e.g., number of mel-filterbank channels).
  1295. """,
  1296. }
  1297. use_auxiliary_loss = {
  1298. "description": """
  1299. Whether to calculate loss using intermediate predictions from transformer decoder.
  1300. """,
  1301. }
  1302. batch_norm_eps = {
  1303. "description": """
  1304. The epsilon used by the batch normalization layers.
  1305. """,
  1306. }
  1307. max_window_layers = {
  1308. "description": """
  1309. The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
  1310. additional layer afterwards will use SWA (Sliding Window Attention).
  1311. """,
  1312. }
  1313. ctc_loss_reduction = {
  1314. "description": """
  1315. Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training.
  1316. """,
  1317. }
  1318. mask_feature_prob = {
  1319. "description": """
  1320. Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
  1321. masking procedure generates `mask_feature_prob*len(feature_axis)/mask_time_length` independent masks over
  1322. the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
  1323. span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
  1324. may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment` is
  1325. `True`.
  1326. """,
  1327. }
  1328. eos_coefficient = {
  1329. "description": """
  1330. Relative classification weight of the 'no-object' class in the object detection loss.
  1331. """,
  1332. }
  1333. num_labels = {
  1334. "description": """
  1335. Number of labels to use in the last layer added to the model, typically for a classification task.
  1336. """,
  1337. }
  1338. depth = {
  1339. "description": """
  1340. Number of Transformer layers in the vision encoder.
  1341. """,
  1342. }
  1343. temporal_patch_size = {
  1344. "description": """
  1345. Temporal patch size used in the 3D patch embedding for video inputs.
  1346. """,
  1347. }
  1348. spatial_merge_size = {
  1349. "description": """
  1350. The size of the spatial merge window used to reduce the number of visual tokens by merging neighboring patches.
  1351. """,
  1352. }
  1353. vision_start_token_id = {
  1354. "description": """
  1355. Token ID that marks the start of a visual segment in the multimodal input sequence.
  1356. """,
  1357. }
  1358. vision_end_token_id = {
  1359. "description": """
  1360. Token ID that marks the end of a visual segment in the multimodal input sequence.
  1361. """,
  1362. }
  1363. mamba_n_heads = {
  1364. "description": """
  1365. The number of mamba heads used in the v2 implementation.
  1366. """,
  1367. }
  1368. mamba_d_head = {
  1369. "description": """
  1370. Head embedding dimension size
  1371. """,
  1372. }
  1373. mamba_n_groups = {
  1374. "description": """
  1375. The number of the mamba groups used in the v2 implementation.
  1376. """,
  1377. }
  1378. mamba_d_conv = {
  1379. "description": """
  1380. The size of the mamba convolution kernel
  1381. """,
  1382. }
  1383. mamba_expand = {
  1384. "description": """
  1385. Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
  1386. """,
  1387. }
  1388. mamba_chunk_size = {
  1389. "description": """
  1390. The chunks in which to break the sequence when doing prefill/training
  1391. """,
  1392. }
  1393. mamba_conv_bias = {
  1394. "description": """
  1395. Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
  1396. """,
  1397. }
  1398. mamba_proj_bias = {
  1399. "description": """
  1400. Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
  1401. """,
  1402. }
  1403. time_step_min = {
  1404. "description": """
  1405. Minimum `time_step` used to bound `dt_proj.bias`.
  1406. """,
  1407. }
  1408. time_step_max = {
  1409. "description": """
  1410. Maximum `time_step` used to bound `dt_proj.bias`.
  1411. """,
  1412. }
  1413. time_step_limit = {
  1414. "description": """
  1415. Accepted range of time step values for clamping.
  1416. """,
  1417. }
  1418. expand_ratio = {
  1419. "description": """
  1420. Expand ratio to set the output dimensions for the expansion
  1421. """,
  1422. }
  1423. state_size = {
  1424. "description": """
  1425. Size of the SSM state (latent state dimension) in the Mamba layers.
  1426. """,
  1427. }
  1428. time_step_rank = {
  1429. "description": """
  1430. Rank of the delta (time step) projection. Can be `"auto"` to set it automatically.
  1431. """,
  1432. }
  1433. time_step_floor = {
  1434. "description": """
  1435. Minimum allowed value for the discrete time step delta after softplus activation.
  1436. """,
  1437. }
  1438. time_step_scale = {
  1439. "description": """
  1440. Scale applied to the time step delta before discretization.
  1441. """,
  1442. }
  1443. time_step_init_scheme = {
  1444. "description": """
  1445. Initialization scheme for the time step delta. Can be `"random"` or `"uniform"`.
  1446. """,
  1447. }
  1448. mamba_d_ssm = {
  1449. "description": """
  1450. Inner state size of the SSM (state-space model) in the Mamba layers of FalconH1.
  1451. """,
  1452. }
  1453. mamba_norm_before_gate = {
  1454. "description": """
  1455. Whether to apply normalization before the gating mechanism in the Mamba mixer.
  1456. """,
  1457. }
  1458. mamba_rms_norm = {
  1459. "description": """
  1460. Whether to use RMS normalization in the Mamba layers (as opposed to standard LayerNorm).
  1461. """,
  1462. }
  1463. mamba_d_state = state_size
  1464. mamba_num_heads = mamba_n_heads
  1465. mamba_head_dim = mamba_d_head
  1466. num_input_channels = num_channels
  1467. audio_channels = num_channels
  1468. input_channels = num_channels
  1469. in_channels = num_channels
  1470. in_chans = num_channels
  1471. scale_attn_weights = scale_embedding
  1472. attention_probs_dropout_prob = attention_dropout
  1473. attn_pdrop = attention_dropout
  1474. attn_dropout = attention_dropout
  1475. dropout = dropout_rate
  1476. resid_dropout = resid_pdrop
  1477. residual_dropout = resid_pdrop
  1478. emb_pdrop = embd_pdrop
  1479. embed_dropout = embd_pdrop
  1480. embedding_dropout = embd_pdrop
  1481. hidden_dropout_prob = hidden_dropout
  1482. hidden_dropout_rate = hidden_dropout
  1483. classifier_dropout_prob = classifier_dropout
  1484. classifier_dropout_rate = classifier_dropout
  1485. dropout_prob = dropout
  1486. dropout_p = dropout
  1487. decoder_attention_dropout = attention_dropout
  1488. decoder_dropout = dropout
  1489. encoder_dropout = dropout
  1490. route_scale = routed_scaling_factor
  1491. activation_function = hidden_act
  1492. hidden_dim = hidden_size
  1493. num_decoder_attention_heads = decoder_attention_heads
  1494. num_encoder_attention_heads = encoder_attention_heads
  1495. decoder_num_heads = decoder_attention_heads
  1496. decoder_num_attention_heads = decoder_attention_heads
  1497. encoder_num_heads = encoder_attention_heads
  1498. encoder_num_attention_heads = encoder_attention_heads
  1499. encoder_layers = num_encoder_layers
  1500. decoder_layers = num_decoder_layers
  1501. decoder_num_layers = num_decoder_layers
  1502. encoder_num_layers = num_encoder_layers
  1503. d_ff = intermediate_size
  1504. dim_ff = intermediate_size
  1505. n_inner = intermediate_size
  1506. decoder_intermediate_size = intermediate_size
  1507. num_kv_heads = num_key_value_heads
  1508. num_layers = num_hidden_layers
  1509. n_layers = num_hidden_layers
  1510. n_layer = num_hidden_layers
  1511. layers = num_layers
  1512. encoder_num_hidden_layers = encoder_layers
  1513. decoder_num_hidden_layers = decoder_layers
  1514. num_heads = num_attention_heads
  1515. n_heads = num_attention_heads
  1516. n_head = num_attention_heads
  1517. hidden_activation = hidden_act
  1518. activation = hidden_act
  1519. mlp_hidden_act = hidden_act
  1520. d_head = head_dim
  1521. d_inner = intermediate_size
  1522. dim_head = head_dim
  1523. ffn_dim = intermediate_size
  1524. attention_heads = num_attention_heads
  1525. n_positions = max_position_embeddings
  1526. init_std = initializer_range
  1527. initializer_std = initializer_range
  1528. projector_bias = multimodal_projector_bias
  1529. image_token_index = image_token_id
  1530. video_token_index = video_token_id
  1531. audio_token_index = audio_token_id
  1532. embedding_size = n_embd
  1533. embed_dim = n_embd
  1534. projection_hidden_act = projector_hidden_act
  1535. layer_norm_epsilon = layer_norm_eps
  1536. rms_norm = rms_norm_eps
  1537. norm_eps = layer_norm_eps
  1538. eps = layer_norm_eps
  1539. norm_epsilon = layer_norm_eps
  1540. qk_layernorms = qk_layernorm
  1541. use_qk_norm = qk_layernorm
  1542. use_qkv_bias = qkv_bias
  1543. decoder_hidden_act = hidden_act
  1544. decoder_hidden_dim = hidden_size
  1545. decoder_hidden_size = hidden_size
  1546. encoder_hidden_dim = hidden_size
  1547. encoder_hidden_size = hidden_size
  1548. layer_scale_initial_scale = layer_scale_init_value
  1549. multi_modal_projector_bias = projector_bias
  1550. projector_hidden_size = projection_dim
  1551. projection_size = projection_dim
  1552. kernel_size = conv_kernel_size
  1553. conv_kernel = conv_kernel_size
  1554. use_absolute_embeddings = use_absolute_position_embeddings
  1555. use_abs_pos = use_absolute_position_embeddings
  1556. use_rel_pos = use_relative_position_bias
  1557. aux_loss_coef = router_aux_loss_coef
  1558. embedding_dimension = embed_dim
  1559. embedding_dim = embed_dim
  1560. emb_dim = embed_dim
  1561. n_codebooks = num_codebooks
  1562. codebook_size = num_codebooks
  1563. layers_block_type = layer_types
  1564. sample_rate = sampling_rate
  1565. text_vocab_size = vocab_size
  1566. class ModelArgs:
  1567. labels = {
  1568. "description": """
  1569. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1570. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1571. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1572. """,
  1573. "shape": "of shape `(batch_size, sequence_length)`",
  1574. }
  1575. num_logits_to_keep = {
  1576. "description": """
  1577. Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
  1578. `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  1579. token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
  1580. """,
  1581. "shape": None,
  1582. }
  1583. input_ids = {
  1584. "description": """
  1585. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.
  1586. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1587. [`PreTrainedTokenizer.__call__`] for details.
  1588. [What are input IDs?](../glossary#input-ids)
  1589. """,
  1590. "shape": "of shape `(batch_size, sequence_length)`",
  1591. }
  1592. input_values = {
  1593. "description": """
  1594. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1595. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1596. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1597. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1598. into a tensor of type `torch.FloatTensor`. See [`{processor_class}.__call__`] for details.
  1599. """,
  1600. "shape": "of shape `(batch_size, sequence_length)`",
  1601. }
  1602. attention_mask = {
  1603. "description": """
  1604. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1605. - 1 for tokens that are **not masked**,
  1606. - 0 for tokens that are **masked**.
  1607. [What are attention masks?](../glossary#attention-mask)
  1608. """,
  1609. "shape": "of shape `(batch_size, sequence_length)`",
  1610. }
  1611. decoder_attention_mask = {
  1612. "description": """
  1613. Mask to avoid performing attention on certain token indices. By default, a causal mask will be used, to
  1614. make sure the model can only look at previous inputs in order to predict the future.
  1615. """,
  1616. "shape": "of shape `(batch_size, target_sequence_length)`",
  1617. }
  1618. encoder_hidden_states = {
  1619. "description": """
  1620. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  1621. if the model is configured as a decoder.
  1622. """,
  1623. "shape": "of shape `(batch_size, sequence_length, hidden_size)`",
  1624. }
  1625. encoder_attention_mask = {
  1626. "description": """
  1627. Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
  1628. the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
  1629. - 1 for tokens that are **not masked**,
  1630. - 0 for tokens that are **masked**.
  1631. """,
  1632. "shape": "of shape `(batch_size, sequence_length)`",
  1633. }
  1634. token_type_ids = {
  1635. "description": """
  1636. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
  1637. - 0 corresponds to a *sentence A* token,
  1638. - 1 corresponds to a *sentence B* token.
  1639. [What are token type IDs?](../glossary#token-type-ids)
  1640. """,
  1641. "shape": "of shape `(batch_size, sequence_length)`",
  1642. }
  1643. mm_token_type_ids = {
  1644. "description": """
  1645. Indices of input sequence tokens matching each modality. For example text (0), image (1), video (2).
  1646. Multimodal token type ids can be obtained using [`AutoProcessor`]. See [`ProcessorMixin.__call__`] for details.
  1647. """,
  1648. "shape": "of shape `(batch_size, sequence_length)`",
  1649. }
  1650. position_ids = {
  1651. "description": """
  1652. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.
  1653. [What are position IDs?](../glossary#position-ids)
  1654. """,
  1655. "shape": "of shape `(batch_size, sequence_length)`",
  1656. }
  1657. past_key_values = {
  1658. "description": """
  1659. Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  1660. blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
  1661. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  1662. Only [`~cache_utils.Cache`] instance is allowed as input, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  1663. If no `past_key_values` are passed, [`~cache_utils.DynamicCache`] will be initialized by default.
  1664. The model will output the same cache format that is fed as input.
  1665. If `past_key_values` are used, the user is expected to input only unprocessed `input_ids` (those that don't
  1666. have their past key value states given to this model) of shape `(batch_size, unprocessed_length)` instead of all `input_ids`
  1667. of shape `(batch_size, sequence_length)`.
  1668. """,
  1669. "shape": None,
  1670. }
  1671. inputs_embeds = {
  1672. "description": """
  1673. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1674. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1675. model's internal embedding lookup matrix.
  1676. """,
  1677. "shape": "of shape `(batch_size, sequence_length, hidden_size)`",
  1678. }
  1679. decoder_input_ids = {
  1680. "description": """
  1681. Indices of decoder input sequence tokens in the vocabulary.
  1682. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1683. [`PreTrainedTokenizer.__call__`] for details.
  1684. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1685. """,
  1686. "shape": "of shape `(batch_size, target_sequence_length)`",
  1687. }
  1688. decoder_inputs_embeds = {
  1689. "description": """
  1690. Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
  1691. representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
  1692. input (see `past_key_values`). This is useful if you want more control over how to convert
  1693. `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
  1694. If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
  1695. of `inputs_embeds`.
  1696. """,
  1697. "shape": "of shape `(batch_size, target_sequence_length, hidden_size)`",
  1698. }
  1699. use_cache = {
  1700. "description": """
  1701. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  1702. `past_key_values`).
  1703. """,
  1704. "shape": None,
  1705. }
  1706. output_attentions = {
  1707. "description": """
  1708. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  1709. tensors for more detail.
  1710. """,
  1711. "shape": None,
  1712. }
  1713. output_hidden_states = {
  1714. "description": """
  1715. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  1716. more detail.
  1717. """,
  1718. "shape": None,
  1719. }
  1720. return_dict = {
  1721. "description": """
  1722. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1723. """,
  1724. "shape": None,
  1725. }
  1726. hidden_states = {
  1727. "description": """ input to the layer of shape `(batch, seq_len, embed_dim)""",
  1728. "shape": None,
  1729. }
  1730. interpolate_pos_encoding = {
  1731. "description": """
  1732. Whether to interpolate the pre-trained position encodings.
  1733. """,
  1734. "shape": None,
  1735. }
  1736. position_embeddings = {
  1737. "description": """
  1738. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  1739. with `head_dim` being the embedding dimension of each attention head.
  1740. """,
  1741. "shape": None,
  1742. }
  1743. config = {
  1744. "description": """
  1745. Model configuration class with all the parameters of the model. Initializing with a config file does not
  1746. load the weights associated with the model, only the configuration. Check out the
  1747. [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  1748. """,
  1749. "shape": None,
  1750. }
  1751. start_positions = {
  1752. "description": """
  1753. Labels for position (index) of the start of the labelled span for computing the token classification loss.
  1754. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  1755. are not taken into account for computing the loss.
  1756. """,
  1757. "shape": "of shape `(batch_size,)`",
  1758. }
  1759. end_positions = {
  1760. "description": """
  1761. Labels for position (index) of the end of the labelled span for computing the token classification loss.
  1762. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  1763. are not taken into account for computing the loss.
  1764. """,
  1765. "shape": "of shape `(batch_size,)`",
  1766. }
  1767. encoder_outputs = {
  1768. "description": """
  1769. Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
  1770. `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
  1771. hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
  1772. """,
  1773. "shape": None,
  1774. }
  1775. output_router_logits = {
  1776. "description": """
  1777. Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
  1778. should not be returned during inference.
  1779. """,
  1780. "shape": None,
  1781. }
  1782. logits_to_keep = {
  1783. "description": """
  1784. If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
  1785. `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  1786. token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
  1787. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
  1788. This is useful when using packed tensor format (single dimension for batch and sequence length).
  1789. """,
  1790. "shape": None,
  1791. }
  1792. pixel_values = {
  1793. "description": """
  1794. The tensors corresponding to the input images. Pixel values can be obtained using
  1795. [`{image_processor_class}`]. See [`{image_processor_class}.__call__`] for details ([`{processor_class}`] uses
  1796. [`{image_processor_class}`] for processing images).
  1797. """,
  1798. "shape": "of shape `(batch_size, num_channels, image_size, image_size)`",
  1799. }
  1800. pixel_values_videos = {
  1801. "description": """
  1802. The tensors corresponding to the input video. Pixel values for videos can be obtained using
  1803. [`{video_processor_class}`]. See [`{video_processor_class}.__call__`] for details ([`{processor_class}`] uses
  1804. [`{video_processor_class}`] for processing videos).
  1805. """,
  1806. "shape": "of shape `(batch_size, num_frames, num_channels, frame_size, frame_size)`",
  1807. }
  1808. vision_feature_layer = {
  1809. "description": """
  1810. The index of the layer to select the vision feature. If multiple indices are provided,
  1811. the vision feature of the corresponding indices will be concatenated to form the
  1812. vision features.
  1813. """,
  1814. "shape": None,
  1815. }
  1816. vision_feature_select_strategy = {
  1817. "description": """
  1818. The feature selection strategy used to select the vision feature from the vision backbone.
  1819. Can be one of `"default"` or `"full"`.
  1820. """,
  1821. "shape": None,
  1822. }
  1823. image_sizes = {
  1824. "description": """
  1825. The sizes of the images in the batch, being (height, width) for each image.
  1826. """,
  1827. "shape": "of shape `(batch_size, 2)`",
  1828. }
  1829. pixel_mask = {
  1830. "description": """
  1831. Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
  1832. - 1 for pixels that are real (i.e. **not masked**),
  1833. - 0 for pixels that are padding (i.e. **masked**).
  1834. [What are attention masks?](../glossary#attention-mask)
  1835. """,
  1836. "shape": "of shape `(batch_size, height, width)`",
  1837. }
  1838. input_features = {
  1839. "description": """
  1840. The tensors corresponding to the input audio features. Audio features can be obtained using
  1841. [`{feature_extractor_class}`]. See [`{feature_extractor_class}.__call__`] for details ([`{processor_class}`] uses
  1842. [`{feature_extractor_class}`] for processing audios).
  1843. """,
  1844. "shape": "of shape `(batch_size, sequence_length, feature_dim)`",
  1845. }
  1846. class ModelOutputArgs:
  1847. last_hidden_state = {
  1848. "description": """
  1849. Sequence of hidden-states at the output of the last layer of the model.
  1850. """,
  1851. "shape": "of shape `(batch_size, sequence_length, hidden_size)`",
  1852. }
  1853. past_key_values = {
  1854. "description": """
  1855. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  1856. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
  1857. `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
  1858. input) to speed up sequential decoding.
  1859. """,
  1860. "shape": None,
  1861. "additional_info": "returned when `use_cache=True` is passed or when `config.use_cache=True`",
  1862. }
  1863. hidden_states = {
  1864. "description": """
  1865. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  1866. one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  1867. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
  1868. """,
  1869. "shape": None,
  1870. "additional_info": "returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`",
  1871. }
  1872. attentions = {
  1873. "description": """
  1874. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  1875. sequence_length)`.
  1876. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  1877. heads.
  1878. """,
  1879. "shape": None,
  1880. "additional_info": "returned when `output_attentions=True` is passed or when `config.output_attentions=True`",
  1881. }
  1882. pooler_output = {
  1883. "description": """
  1884. Last layer hidden-state after a pooling operation on the spatial dimensions.
  1885. """,
  1886. "shape": "of shape `(batch_size, hidden_size)`",
  1887. }
  1888. cross_attentions = {
  1889. "description": """
  1890. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  1891. sequence_length)`.
  1892. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
  1893. weighted average in the cross-attention heads.
  1894. """,
  1895. "shape": None,
  1896. "additional_info": "returned when `output_attentions=True` is passed or when `config.output_attentions=True`",
  1897. }
  1898. decoder_hidden_states = {
  1899. "description": """
  1900. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  1901. one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  1902. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
  1903. """,
  1904. "shape": None,
  1905. "additional_info": "returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`",
  1906. }
  1907. decoder_attentions = {
  1908. "description": """
  1909. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  1910. sequence_length)`.
  1911. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
  1912. self-attention heads.
  1913. """,
  1914. "shape": None,
  1915. "additional_info": "returned when `output_attentions=True` is passed or when `config.output_attentions=True`",
  1916. }
  1917. encoder_last_hidden_state = {
  1918. "description": """
  1919. Sequence of hidden-states at the output of the last layer of the encoder of the model.
  1920. """,
  1921. "shape": "of shape `(batch_size, sequence_length, hidden_size)`",
  1922. }
  1923. encoder_hidden_states = {
  1924. "description": """
  1925. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  1926. one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  1927. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
  1928. """,
  1929. "shape": None,
  1930. "additional_info": "returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`",
  1931. }
  1932. encoder_attentions = {
  1933. "description": """
  1934. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  1935. sequence_length)`.
  1936. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
  1937. self-attention heads.
  1938. """,
  1939. "shape": None,
  1940. "additional_info": "returned when `output_attentions=True` is passed or when `config.output_attentions=True`",
  1941. }
  1942. router_logits = {
  1943. "description": """
  1944. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
  1945. Router logits of the model, useful to compute the auxiliary loss for Mixture of Experts models.
  1946. """,
  1947. "shape": None,
  1948. "additional_info": "returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`",
  1949. }
  1950. router_probs = {
  1951. "description": """
  1952. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
  1953. Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary
  1954. loss and the z_loss for Mixture of Experts models.
  1955. """,
  1956. "shape": None,
  1957. "additional_info": "returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`",
  1958. }
  1959. z_loss = {
  1960. "description": """
  1961. z_loss for the sparse modules.
  1962. """,
  1963. "shape": None,
  1964. "additional_info": "returned when `labels` is provided",
  1965. }
  1966. aux_loss = {
  1967. "description": """
  1968. aux_loss for the sparse modules.
  1969. """,
  1970. "shape": None,
  1971. "additional_info": "returned when `labels` is provided",
  1972. }
  1973. start_logits = {
  1974. "description": """
  1975. Span-start scores (before SoftMax).
  1976. """,
  1977. "shape": "of shape `(batch_size, sequence_length)`",
  1978. }
  1979. end_logits = {
  1980. "description": """
  1981. Span-end scores (before SoftMax).
  1982. """,
  1983. "shape": "of shape `(batch_size, sequence_length)`",
  1984. }
  1985. feature_maps = {
  1986. "description": """
  1987. Feature maps of the stages.
  1988. """,
  1989. "shape": "of shape `(batch_size, num_channels, height, width)`",
  1990. }
  1991. reconstruction = {
  1992. "description": """
  1993. Reconstructed / completed images.
  1994. """,
  1995. "shape": "of shape `(batch_size, num_channels, height, width)`",
  1996. }
  1997. spectrogram = {
  1998. "description": """
  1999. The predicted spectrogram.
  2000. """,
  2001. "shape": "of shape `(batch_size, sequence_length, num_bins)`",
  2002. }
  2003. predicted_depth = {
  2004. "description": """
  2005. Predicted depth for each pixel.
  2006. """,
  2007. "shape": "of shape `(batch_size, height, width)`",
  2008. }
  2009. sequences = {
  2010. "description": """
  2011. Sampled values from the chosen distribution.
  2012. """,
  2013. "shape": "of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`",
  2014. }
  2015. params = {
  2016. "description": """
  2017. Parameters of the chosen distribution.
  2018. """,
  2019. "shape": "of shape `(batch_size, num_samples, num_params)`",
  2020. }
  2021. loc = {
  2022. "description": """
  2023. Shift values of each time series' context window which is used to give the model inputs of the same
  2024. magnitude and then used to shift back to the original magnitude.
  2025. """,
  2026. "shape": "of shape `(batch_size,)` or `(batch_size, input_size)`",
  2027. }
  2028. scale = {
  2029. "description": """
  2030. Scaling values of each time series' context window which is used to give the model inputs of the same
  2031. magnitude and then used to rescale back to the original magnitude.
  2032. """,
  2033. "shape": "of shape `(batch_size,)` or `(batch_size, input_size)`",
  2034. }
  2035. static_features = {
  2036. "description": """
  2037. Static features of each time series' in a batch which are copied to the covariates at inference time.
  2038. """,
  2039. "shape": "of shape `(batch_size, feature size)`",
  2040. }
  2041. embeddings = {
  2042. "description": """
  2043. Utterance embeddings used for vector similarity-based retrieval.
  2044. """,
  2045. "shape": "of shape `(batch_size, config.xvector_output_dim)`",
  2046. }
  2047. extract_features = {
  2048. "description": """
  2049. Sequence of extracted feature vectors of the last convolutional layer of the model.
  2050. """,
  2051. "shape": "of shape `(batch_size, sequence_length, conv_dim[-1])`",
  2052. }
  2053. projection_state = {
  2054. "description": """
  2055. Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder.
  2056. """,
  2057. "shape": "of shape `(batch_size,config.project_dim)`",
  2058. }
  2059. image_hidden_states = {
  2060. "description": """
  2061. Image hidden states of the model produced by the vision encoder and after projecting the last hidden state.
  2062. """,
  2063. "shape": "of shape `(batch_size, num_images, sequence_length, hidden_size)`",
  2064. }
  2065. video_hidden_states = {
  2066. "description": """
  2067. Video hidden states of the model produced by the vision encoder and after projecting the last hidden state.
  2068. """,
  2069. "shape": "of shape `(batch_size * num_frames, num_images, sequence_length, hidden_size)`",
  2070. }
  2071. class ClassDocstring:
  2072. Config = r"""
  2073. This is the configuration class to store the configuration of a {model_base_class}. It is used to instantiate a {model_name}
  2074. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  2075. defaults will yield a similar configuration to that of the [{model_checkpoint}](https://huggingface.co/{model_checkpoint})
  2076. Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
  2077. documentation from [`PreTrainedConfig`] for more information.
  2078. """
  2079. PreTrainedModel = r"""
  2080. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  2081. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  2082. etc.)
  2083. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  2084. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  2085. and behavior.
  2086. """
  2087. Model = r"""
  2088. The bare {model_name} Model outputting raw hidden-states without any specific head on top.
  2089. """
  2090. ForPreTraining = r"""
  2091. The {model_name} Model with a specified pretraining head on top.
  2092. """
  2093. Decoder = r"""
  2094. The bare {model_name} Decoder outputting raw hidden-states without any specific head on top.
  2095. """
  2096. TextModel = r"""
  2097. The bare {model_name} Text Model outputting raw hidden-states without any specific head on to.
  2098. """
  2099. ForSequenceClassification = r"""
  2100. The {model_name} Model with a sequence classification/regression head on top e.g. for GLUE tasks.
  2101. """
  2102. ForQuestionAnswering = r"""
  2103. The {model_name} transformer with a span classification head on top for extractive question-answering tasks like
  2104. SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
  2105. """
  2106. ForMultipleChoice = r"""
  2107. The {model_name} Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
  2108. softmax) e.g. for RocStories/SWAG tasks.
  2109. """
  2110. ForMaskedLM = r"""
  2111. The {model_name} Model with a `language modeling` head on top."
  2112. """
  2113. ForTokenClassification = r"""
  2114. The {model_name} transformer with a token classification head on top (a linear layer on top of the hidden-states
  2115. output) e.g. for Named-Entity-Recognition (NER) tasks.
  2116. """
  2117. ForConditionalGeneration = r"""
  2118. The {model_name} Model for token generation conditioned on other modalities (e.g. image-text-to-text generation).
  2119. """
  2120. ForCausalLM = r"""
  2121. The {model_name} Model for causal language modeling.
  2122. """
  2123. Backbone = r"""
  2124. The {model_name} backbone.
  2125. """
  2126. ForImageClassification = r"""
  2127. The {model_name} Model with an image classification head on top e.g. for ImageNet.
  2128. """
  2129. ForSemanticSegmentation = r"""
  2130. The {model_name} Model with a semantic segmentation head on top e.g. for ADE20K, CityScapes.
  2131. """
  2132. ForAudioClassification = r"""
  2133. The {model_name} Model with an audio classification head on top (a linear layer on top of the pooled
  2134. output).
  2135. """
  2136. ForAudioFrameClassification = r"""
  2137. The {model_name} Model with a frame classification head on top for tasks like Speaker Diarization.
  2138. """
  2139. ForPrediction = r"""
  2140. The {model_name} Model with a distribution head on top for time-series forecasting.
  2141. """
  2142. WithProjection = r"""
  2143. The {model_name} Model with a projection layer on top (a linear layer on top of the pooled output).
  2144. """
  2145. class ClassAttrs:
  2146. # fmt: off
  2147. base_model_prefix = r"""
  2148. A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
  2149. """
  2150. supports_gradient_checkpointing = r"""
  2151. Whether the model supports gradient checkpointing or not. Gradient checkpointing is a memory-saving technique that trades compute for memory, by storing only a subset of activations (checkpoints) and recomputing the activations that are not stored during the backward pass.
  2152. """
  2153. _no_split_modules = r"""
  2154. Layers of modules that should not be split across devices should be added to `_no_split_modules`. This can be useful for modules that contains skip connections or other operations that are not compatible with splitting the module across devices. Setting this attribute will enable the use of `device_map="auto"` in the `from_pretrained` method.
  2155. """
  2156. _skip_keys_device_placement = r"""
  2157. A list of keys to ignore when moving inputs or outputs between devices when using the `accelerate` library.
  2158. """
  2159. _supports_flash_attn = r"""
  2160. Whether the model's attention implementation supports FlashAttention.
  2161. """
  2162. _supports_sdpa = r"""
  2163. Whether the model's attention implementation supports SDPA (Scaled Dot Product Attention).
  2164. """
  2165. _supports_flex_attn = r"""
  2166. Whether the model's attention implementation supports FlexAttention.
  2167. """
  2168. _can_compile_fullgraph = r"""
  2169. Whether the model can `torch.compile` fullgraph without graph breaks. Models will auto-compile if this flag is set to `True`
  2170. in inference, if a compilable cache is used.
  2171. """
  2172. _supports_attention_backend = r"""
  2173. Whether the model supports attention interface functions. This flag signal that the model can be used as an efficient backend in TGI and vLLM.
  2174. """
  2175. _tied_weights_keys = r"""
  2176. A list of `state_dict` keys that are potentially tied to another key in the state_dict.
  2177. """
  2178. # fmt: on
# Argument names collected here are presumably left out of the generated argument
# documentation (the code consuming this set is outside this part of the file — confirm).
ARGS_TO_IGNORE = {"self", "kwargs", "args", "deprecated_arguments"}
# Maps underscore-prefixed argument names to their public spelling; presumably applied when
# looking up or printing those arguments (consumer not visible here — confirm).
ARGS_TO_RENAME = {"_out_features": "out_features", "_out_indices": "out_indices"}
  2181. def get_indent_level(func):
  2182. # Use this instead of `inspect.getsource(func)` as getsource can be very slow
  2183. return (len(func.__qualname__.split(".")) - 1) * 4
  2184. def equalize_indent(docstring: str, indent_level: int) -> str:
  2185. """
  2186. Adjust the indentation of a docstring to match the specified indent level.
  2187. """
  2188. prefix = " " * indent_level
  2189. # Uses splitlines() (no keepends) to match previous behaviour that dropped
  2190. # any trailing newline via the old splitlines() + "\n".join() + textwrap.indent path.
  2191. return "\n".join(prefix + line.lstrip() if line.strip() else "" for line in docstring.splitlines())
  2192. def set_min_indent(docstring: str, indent_level: int) -> str:
  2193. """
  2194. Adjust the indentation of a docstring to match the specified indent level.
  2195. """
  2196. # Equivalent to textwrap.dedent + textwrap.indent but avoids the two regex
  2197. # passes that textwrap uses internally (one per call in dedent, one in indent).
  2198. lines = docstring.split("\n")
  2199. min_indent = min(
  2200. (len(line) - len(line.lstrip()) for line in lines if line.strip()),
  2201. default=0,
  2202. )
  2203. prefix = " " * indent_level
  2204. return "\n".join(prefix + line[min_indent:] if line.strip() else "" for line in lines)
  2205. def parse_shape(docstring):
  2206. match = _re_shape.search(docstring)
  2207. if match:
  2208. return " " + match.group(1)
  2209. return None
  2210. def parse_default(docstring):
  2211. match = _re_default.search(docstring)
  2212. if match:
  2213. return " " + match.group(1)
  2214. return None
  2215. def parse_docstring(docstring, max_indent_level=0, return_intro=False):
  2216. """
  2217. Parse the docstring to extract the Args section and return it as a dictionary.
  2218. The docstring is expected to be in the format:
  2219. Args:
  2220. arg1 (type):
  2221. Description of arg1.
  2222. arg2 (type):
  2223. Description of arg2.
  2224. # This function will also return the remaining part of the docstring after the Args section.
  2225. Returns:/Example:
  2226. ...
  2227. """
  2228. match = _re_example_or_return.search(docstring)
  2229. if match:
  2230. remainder_docstring = docstring[match.start() :]
  2231. docstring = docstring[: match.start()]
  2232. else:
  2233. remainder_docstring = ""
  2234. args_match = _re_args_section.search(docstring)
  2235. # still try to find args description in the docstring, if args are not preceded by "Args:"
  2236. docstring_intro = None
  2237. if args_match:
  2238. docstring_intro = docstring[: args_match.start()]
  2239. if docstring_intro.split("\n")[-1].strip() == '"""':
  2240. docstring_intro = "\n".join(docstring_intro.split("\n")[:-1])
  2241. if docstring_intro.split("\n")[0].strip() == 'r"""' or docstring_intro.split("\n")[0].strip() == '"""':
  2242. docstring_intro = "\n".join(docstring_intro.split("\n")[1:])
  2243. if docstring_intro.strip() == "":
  2244. docstring_intro = None
  2245. args_section = args_match.group(1).lstrip("\n") if args_match else docstring
  2246. if args_section.split("\n")[-1].strip() == '"""':
  2247. args_section = "\n".join(args_section.split("\n")[:-1])
  2248. if args_section.split("\n")[0].strip() == 'r"""' or args_section.split("\n")[0].strip() == '"""':
  2249. args_section = "\n".join(args_section.split("\n")[1:])
  2250. args_section = set_min_indent(args_section, 0)
  2251. params = {}
  2252. if args_section:
  2253. # Use the pre-compiled pattern (max_indent_level is always 0 at all call
  2254. # sites; if a non-zero value is ever needed, compile a fresh pattern).
  2255. if max_indent_level == 0:
  2256. param_pattern = _re_param
  2257. else:
  2258. param_pattern = re.compile(
  2259. # |--- Group 1 ---|| Group 2 ||- Group 3 -||---------- Group 4 ----------|
  2260. rf"^\s{{0,{max_indent_level}}}(\w+)\s*\(\s*([^, \)]*)(\s*.*?)\s*\)\s*:\s*((?:(?!\n^\s{{0,{max_indent_level}}}\w+\s*\().)*)",
  2261. re.DOTALL | re.MULTILINE,
  2262. )
  2263. for match in param_pattern.finditer(args_section):
  2264. param_name = match.group(1)
  2265. param_type = match.group(2)
  2266. additional_info = match.group(3)
  2267. optional = "optional" in additional_info
  2268. shape = parse_shape(additional_info)
  2269. default = parse_default(additional_info)
  2270. param_description = match.group(4).strip()
  2271. # indent the first line of param_description to 4 spaces:
  2272. param_description = " " * 4 + param_description
  2273. param_description = f"\n{param_description}"
  2274. params[param_name] = {
  2275. "type": param_type,
  2276. "description": param_description,
  2277. "optional": optional,
  2278. "shape": shape,
  2279. "default": default,
  2280. "additional_info": additional_info,
  2281. }
  2282. if params and remainder_docstring:
  2283. remainder_docstring = "\n" + remainder_docstring
  2284. remainder_docstring = set_min_indent(remainder_docstring, 0)
  2285. if return_intro:
  2286. return params, remainder_docstring, docstring_intro
  2287. return params, remainder_docstring
  2288. def contains_type(type_hint, target_type) -> tuple[bool, object | None]:
  2289. """
  2290. Check if a "nested" type hint contains a specific target type,
  2291. return the first-level type containing the target_type if found.
  2292. """
  2293. args = get_args(type_hint)
  2294. if args == ():
  2295. try:
  2296. return issubclass(type_hint, target_type), type_hint
  2297. except Exception:
  2298. return issubclass(type(type_hint), target_type), type_hint
  2299. found_type_tuple = [contains_type(arg, target_type)[0] for arg in args]
  2300. found_type = any(found_type_tuple)
  2301. if found_type:
  2302. type_hint = args[found_type_tuple.index(True)]
  2303. return found_type, type_hint
  2304. def get_model_name(obj):
  2305. """
  2306. Get the model name from the file path of the object.
  2307. """
  2308. path = inspect.getsourcefile(obj)
  2309. if path is None:
  2310. return None
  2311. if path.split(os.path.sep)[-3] != "models":
  2312. return None
  2313. file_name = path.split(os.path.sep)[-1]
  2314. model_name_lowercase_from_folder = path.split(os.path.sep)[-2]
  2315. model_name_lowercase_from_file = None
  2316. for file_type in AUTODOC_FILES:
  2317. start = file_type.split("*")[0]
  2318. end = file_type.split("*")[-1] if "*" in file_type else ""
  2319. if file_name.startswith(start) and file_name.endswith(end):
  2320. model_name_lowercase_from_file = file_name[len(start) : -len(end)]
  2321. break
  2322. if model_name_lowercase_from_file and model_name_lowercase_from_folder != model_name_lowercase_from_file:
  2323. from transformers.models.auto.configuration_auto import SPECIAL_MODEL_TYPE_TO_MODULE_NAME
  2324. if (
  2325. model_name_lowercase_from_file in SPECIAL_MODEL_TYPE_TO_MODULE_NAME
  2326. or model_name_lowercase_from_file.replace("_", "-") in SPECIAL_MODEL_TYPE_TO_MODULE_NAME
  2327. ):
  2328. return model_name_lowercase_from_file
  2329. return model_name_lowercase_from_folder
  2330. return model_name_lowercase_from_folder
  2331. def generate_processor_intro(cls) -> str:
  2332. """
  2333. Generate the intro docstring for a processor class based on its attributes.
  2334. Args:
  2335. cls: Processor class to generate intro for
  2336. Returns:
  2337. str: Generated intro text
  2338. """
  2339. class_name = cls.__name__
  2340. # Get attributes and their corresponding class names
  2341. attributes = cls.get_attributes()
  2342. if not attributes:
  2343. return ""
  2344. # Build list of component names and their classes
  2345. components = []
  2346. component_classes = []
  2347. for attr in attributes:
  2348. # Get the class name for this attribute
  2349. class_attr = f"{attr}_class"
  2350. # Format attribute name for display
  2351. attr_display = attr.replace("_", " ")
  2352. components.append(attr_display)
  2353. component_classes.append(f"[`{{{class_attr}}}`]")
  2354. if not components:
  2355. return ""
  2356. # Generate the intro text
  2357. if len(components) == 1:
  2358. components_text = f"a {components[0]}"
  2359. classes_text = component_classes[0]
  2360. classes_text_short = component_classes[0].replace("[`", "[`~")
  2361. elif len(components) == 2:
  2362. components_text = f"a {components[0]} and a {components[1]}"
  2363. classes_text = f"{component_classes[0]} and {component_classes[1]}"
  2364. classes_text_short = (
  2365. f"{component_classes[0].replace('[`', '[`~')} and {component_classes[1].replace('[`', '[`~')}"
  2366. )
  2367. else:
  2368. components_text = ", ".join(f"a {c}" for c in components[:-1]) + f", and a {components[-1]}"
  2369. classes_text = ", ".join(component_classes[:-1]) + f", and {component_classes[-1]}"
  2370. classes_short = [c.replace("[`", "[`~") for c in component_classes]
  2371. classes_text_short = ", ".join(classes_short[:-1]) + f", and {classes_short[-1]}"
  2372. intro = f"""Constructs a {class_name} which wraps {components_text} into a single processor.
  2373. [`{class_name}`] offers all the functionalities of {classes_text}. See the
  2374. {classes_text_short} for more information.
  2375. """
  2376. return intro
  2377. def get_placeholders_dict(placeholders: set[str], model_name: str) -> Mapping[str, str | None]:
  2378. """
  2379. Get the dictionary of placeholders for the given model name.
  2380. """
  2381. # import here to avoid circular import
  2382. from transformers.models import auto as auto_module
  2383. placeholders_dict = {}
  2384. for placeholder in placeholders:
  2385. # Infer placeholders from the model name and the auto modules
  2386. if placeholder in PLACEHOLDER_TO_AUTO_MODULE:
  2387. try:
  2388. place_holder_value = getattr(
  2389. getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE[placeholder][0]),
  2390. PLACEHOLDER_TO_AUTO_MODULE[placeholder][1],
  2391. ).get(model_name, None)
  2392. except ImportError:
  2393. # In case a library is not installed, we don't want to fail the docstring generation
  2394. place_holder_value = None
  2395. if place_holder_value is not None:
  2396. if isinstance(place_holder_value, (list, tuple)):
  2397. place_holder_value = (
  2398. place_holder_value[-1] if place_holder_value[-1] is not None else place_holder_value[0]
  2399. )
  2400. placeholders_dict[placeholder] = place_holder_value if place_holder_value is not None else placeholder
  2401. else:
  2402. placeholders_dict[placeholder] = placeholder
  2403. return placeholders_dict
  2404. def format_args_docstring(docstring: str, model_name: str) -> str:
  2405. """
  2406. Replaces placeholders such as {image_processor_class} in the docstring with the actual values,
  2407. deducted from the model name and the auto modules.
  2408. """
  2409. # first check if there are any placeholders in the docstring, if not return it as is
  2410. placeholders = set(_re_placeholders.findall(docstring))
  2411. if not placeholders:
  2412. return docstring
  2413. # get the placeholders dictionary for the given model name
  2414. placeholders_dict = get_placeholders_dict(placeholders, model_name)
  2415. # replace the placeholders in the docstring with the values from the placeholders_dict
  2416. for placeholder, value in placeholders_dict.items():
  2417. if isinstance(value, dict) and placeholder == "image_processor_class":
  2418. value = value.get("torchvision", value.get("pil", None))
  2419. if placeholder is not None:
  2420. docstring = docstring.replace(f"{{{placeholder}}}", value)
  2421. return docstring
  2422. def get_args_doc_from_source(args_classes: object | list[object]) -> dict:
  2423. if isinstance(args_classes, (list, tuple)):
  2424. return _merge_args_dicts(tuple(args_classes))
  2425. return args_classes.__dict__
  2426. @lru_cache(maxsize=16)
  2427. def _merge_args_dicts(args_classes_tuple: tuple) -> dict:
  2428. """Cached merger of args-doc dicts. The input classes are static so caching is safe."""
  2429. result = {}
  2430. for cls in args_classes_tuple:
  2431. result.update(cls.__dict__)
  2432. return result
  2433. def get_checkpoint_from_config_class(config_class):
  2434. checkpoint = None
  2435. # source code of `config_class`
  2436. # config_source = inspect.getsource(config_class)
  2437. config_source = config_class.__doc__
  2438. if not config_source:
  2439. return None
  2440. checkpoints = _re_checkpoint.findall(config_source)
  2441. # Each `checkpoint` is a tuple of a checkpoint name and a checkpoint link.
  2442. # For example, `('google-bert/bert-base-uncased', 'https://huggingface.co/google-bert/bert-base-uncased')`
  2443. for ckpt_name, ckpt_link in checkpoints:
  2444. # allow the link to end with `/`
  2445. ckpt_link = ckpt_link.removesuffix("/")
  2446. # verify the checkpoint name corresponds to the checkpoint link
  2447. ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}"
  2448. if ckpt_link == ckpt_link_from_name:
  2449. checkpoint = ckpt_name
  2450. break
  2451. return checkpoint
  2452. def add_intro_docstring(func, class_name, indent_level=0):
  2453. intro_docstring = ""
  2454. if func.__name__ == "forward":
  2455. intro_docstring = rf"""The [`{class_name}`] forward method, overrides the `__call__` special method.
  2456. <Tip>
  2457. Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
  2458. instance afterwards instead of this since the former takes care of running the pre and post processing steps while
  2459. the latter silently ignores them.
  2460. </Tip>
  2461. """
  2462. intro_docstring = equalize_indent(intro_docstring, indent_level + 4)
  2463. return intro_docstring
  2464. def _get_model_info(func, parent_class):
  2465. """
  2466. Extract model information from a function or its parent class.
  2467. Args:
  2468. func (`function`): The function to extract information from
  2469. parent_class (`class`): Optional parent class of the function
  2470. """
  2471. # import here to avoid circular import
  2472. from transformers.models import auto as auto_module
  2473. # Get model name from either parent class or function
  2474. if parent_class is not None:
  2475. model_name_lowercase = get_model_name(parent_class)
  2476. else:
  2477. model_name_lowercase = get_model_name(func)
  2478. # Normalize model name if needed
  2479. if model_name_lowercase and model_name_lowercase not in getattr(
  2480. getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE["config_class"][0]),
  2481. PLACEHOLDER_TO_AUTO_MODULE["config_class"][1],
  2482. ):
  2483. model_name_lowercase = model_name_lowercase.replace("_", "-")
  2484. # Get class name from function's qualified name
  2485. class_name = func.__qualname__.split(".")[0]
  2486. # Get config class for the model
  2487. if model_name_lowercase is None:
  2488. config_class = None
  2489. else:
  2490. try:
  2491. config_class = getattr(
  2492. getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE["config_class"][0]),
  2493. PLACEHOLDER_TO_AUTO_MODULE["config_class"][1],
  2494. )[model_name_lowercase]
  2495. except KeyError:
  2496. if model_name_lowercase in HARDCODED_CONFIG_FOR_MODELS:
  2497. config_class = HARDCODED_CONFIG_FOR_MODELS[model_name_lowercase]
  2498. else:
  2499. config_class = "ModelConfig"
  2500. print(
  2501. f"[ERROR] Config not found for {model_name_lowercase}. You can manually add it to HARDCODED_CONFIG_FOR_MODELS in utils/auto_docstring.py"
  2502. )
  2503. return model_name_lowercase, class_name, config_class
  2504. def _format_type_annotation_recursive(type_hint):
  2505. """
  2506. Recursively format a type annotation object as a string, preserving generic type arguments.
  2507. This is an internal helper used by process_type_annotation for the type object path.
  2508. Args:
  2509. type_hint: A type annotation object
  2510. Returns:
  2511. str: Formatted type string
  2512. """
  2513. # Handle special cases
  2514. if type_hint is type(...) or type_hint is Ellipsis:
  2515. return "..."
  2516. # Note: NoneType handling is done later to preserve "NoneType" in Union[] but "None" in | syntax
  2517. # Check if this is a generic type (e.g., list[str], dict[str, int])
  2518. origin = get_origin(type_hint)
  2519. args = get_args(type_hint)
  2520. if origin is not None and args:
  2521. # This is a generic type - format it with its arguments
  2522. # Get the origin type name
  2523. if hasattr(origin, "__module__") and hasattr(origin, "__name__"):
  2524. # Clean up module name - need to handle both 'typing.' prefix and just 'typing'
  2525. module_name = origin.__module__
  2526. if module_name in ("typing", "types", "builtins"):
  2527. module_name = ""
  2528. else:
  2529. module_name = (
  2530. module_name.replace("transformers.", "~")
  2531. .replace("typing.", "")
  2532. .replace("types.", "")
  2533. .replace("builtins.", "")
  2534. )
  2535. if module_name:
  2536. origin_str = f"{module_name}.{origin.__name__}"
  2537. else:
  2538. origin_str = origin.__name__
  2539. else:
  2540. origin_str = str(origin)
  2541. # Handle special origin types
  2542. if origin_str == "UnionType":
  2543. # Python 3.13's X | Y syntax - format it nicely
  2544. arg_strs = [_format_type_annotation_recursive(arg) for arg in args]
  2545. return " | ".join(arg_strs)
  2546. # Special handling for Annotated[Union[...], ...] and Annotated[UnionType[...], ...]
  2547. # Check if first arg is a Union/UnionType and format it specially
  2548. if origin_str == "Annotated" and args:
  2549. first_arg_origin = get_origin(args[0])
  2550. # Check if it's a UnionType (modern | syntax) or Union (old Union[] syntax)
  2551. if first_arg_origin is UnionType:
  2552. # Modern union type - format as X | Y | Z (with None not NoneType)
  2553. union_args = get_args(args[0])
  2554. union_strs = []
  2555. for arg in union_args:
  2556. if arg is type(None):
  2557. union_strs.append("None") # Modern syntax uses "None"
  2558. else:
  2559. union_strs.append(_format_type_annotation_recursive(arg))
  2560. formatted_union = " | ".join(union_strs)
  2561. # Include the rest of the Annotated metadata
  2562. remaining_args = [_format_type_annotation_recursive(arg) for arg in args[1:]]
  2563. all_args = [formatted_union] + remaining_args
  2564. return f"{origin_str}[{', '.join(all_args)}]"
  2565. elif first_arg_origin is Union:
  2566. # Old-style Union - format as Union[X, Y, Z]
  2567. union_args = get_args(args[0])
  2568. union_strs = [_format_type_annotation_recursive(arg) for arg in union_args]
  2569. formatted_union = f"Union[{', '.join(union_strs)}]"
  2570. # Include the rest of the Annotated metadata
  2571. remaining_args = [_format_type_annotation_recursive(arg) for arg in args[1:]]
  2572. all_args = [formatted_union] + remaining_args
  2573. return f"{origin_str}[{', '.join(all_args)}]"
  2574. # Recursively format the generic arguments
  2575. arg_strs = [_format_type_annotation_recursive(arg) for arg in args]
  2576. return f"{origin_str}[{', '.join(arg_strs)}]"
  2577. elif hasattr(type_hint, "__module__") and hasattr(type_hint, "__name__"):
  2578. # Simple type with module and name
  2579. # Clean up module name - need to handle both 'typing.' prefix and just 'typing'
  2580. module_name = type_hint.__module__
  2581. if module_name in ("typing", "types", "builtins"):
  2582. module_name = ""
  2583. else:
  2584. module_name = (
  2585. module_name.replace("transformers.", "~")
  2586. .replace("typing.", "")
  2587. .replace("types.", "")
  2588. .replace("builtins.", "")
  2589. )
  2590. if module_name:
  2591. type_name = f"{module_name}.{type_hint.__name__}"
  2592. else:
  2593. type_name = type_hint.__name__
  2594. return type_name
  2595. else:
  2596. # Fallback to string representation
  2597. type_str = str(type_hint)
  2598. # Clean up ForwardRef
  2599. if "ForwardRef" in type_str:
  2600. type_str = _re_forward_ref.sub(r"\1", type_str)
  2601. # Clean up module prefixes
  2602. type_str = type_str.replace("typing.", "").replace("types.", "")
  2603. return type_str
  2604. def process_type_annotation(type_input, param_name: str | None = None) -> tuple[str, bool]:
  2605. """
  2606. Unified function to process and format a parameter's type annotation.
  2607. This function intelligently handles both type objects (from inspect.Parameter.annotation)
  2608. and string representations of types. It will:
  2609. - Use type introspection when given a type object (preserves generic arguments)
  2610. - Parse string representations when that's all that's available
  2611. - Always return a formatted type string and optional flag
  2612. Handles various type representations including:
  2613. - Type objects with generics (e.g., list[str], Optional[int])
  2614. - Union types (both Union[X, Y] and X | Y syntax)
  2615. - Modern union syntax with | (e.g., "bool | None")
  2616. - Complex typing constructs (Union, Optional, Annotated, etc.)
  2617. - Generic types with brackets
  2618. - Class type strings
  2619. - Simple types and module paths
  2620. Args:
  2621. type_input: Either a type annotation object or a string representation of a type
  2622. param_name (`str | None`): The parameter name (used for legacy module path handling)
  2623. Returns:
  2624. tuple[str, bool]: (formatted_type_string, is_optional)
  2625. """
  2626. optional = False
  2627. # Path 1: Type object (best approach - preserves generic type information)
  2628. if not isinstance(type_input, str):
  2629. # Handle None type
  2630. if type_input is None or type_input is type(None):
  2631. return "None", True
  2632. # Handle Union types and modern UnionType (X | Y)
  2633. if get_origin(type_input) is Union or get_origin(type_input) is UnionType:
  2634. subtypes = get_args(type_input)
  2635. out_str = []
  2636. for subtype in subtypes:
  2637. if subtype is type(None):
  2638. optional = True
  2639. continue
  2640. formatted_type = _format_type_annotation_recursive(subtype)
  2641. out_str.append(formatted_type)
  2642. if not out_str:
  2643. return "", optional
  2644. elif len(out_str) == 1:
  2645. return out_str[0], optional
  2646. else:
  2647. return f"Union[{', '.join(out_str)}]", optional
  2648. # Single type (not a Union)
  2649. formatted_type = _format_type_annotation_recursive(type_input)
  2650. return formatted_type, optional
  2651. # Path 2: String representation (fallback when we only have strings)
  2652. param_type = type_input
  2653. # Handle Union types with | syntax
  2654. if " | " in param_type:
  2655. # Modern union syntax (e.g., "bool | None")
  2656. parts = [p.strip() for p in param_type.split(" | ")]
  2657. if "None" in parts:
  2658. optional = True
  2659. parts = [p for p in parts if p != "None"]
  2660. param_type = " | ".join(parts) if parts else ""
  2661. # Clean up module prefixes including typing
  2662. param_type = "".join(param_type.split("typing.")).replace("transformers.", "~").replace("builtins.", "")
  2663. elif "typing" in param_type or "Union[" in param_type or "Optional[" in param_type or "[" in param_type:
  2664. # Complex typing construct or generic type - clean up typing module references
  2665. param_type = "".join(param_type.split("typing.")).replace("transformers.", "~")
  2666. elif "<class '" in param_type:
  2667. # This is a class type like "<class 'module.ClassName'>" - should NOT append param_name
  2668. param_type = (
  2669. param_type.replace("transformers.", "~").replace("builtins.", "").replace("<class '", "").replace("'>", "")
  2670. )
  2671. else:
  2672. # Simple type or module path - only append param_name if it looks like a module path
  2673. # This is legacy behavior for backwards compatibility
  2674. if param_name and "." in param_type and not param_type.split(".")[-1][0].isupper():
  2675. # Looks like a module path ending with an attribute
  2676. param_type = f"{param_type.replace('transformers.', '~').replace('builtins', '')}.{param_name}"
  2677. else:
  2678. # Simple type name, don't append param_name
  2679. param_type = param_type.replace("transformers.", "~").replace("builtins.", "")
  2680. # Clean up ForwardRef
  2681. if "ForwardRef" in param_type:
  2682. param_type = _re_forward_ref.sub(r"\1", param_type)
  2683. # Handle Optional wrapper
  2684. if "Optional" in param_type:
  2685. param_type = _re_optional.sub(r"\1", param_type)
  2686. optional = True
  2687. return param_type, optional
  2688. def _process_parameter_type(param):
  2689. """
  2690. Process and format a parameter's type annotation from an inspect.Parameter object.
  2691. Args:
  2692. param (`inspect.Parameter`): The parameter from the function signature
  2693. Returns:
  2694. tuple[str, bool]: (formatted_type_string, is_optional)
  2695. """
  2696. if param.annotation == inspect.Parameter.empty:
  2697. return "", False
  2698. # Use the unified function to process the type annotation
  2699. formatted_type, optional = process_type_annotation(param.annotation)
  2700. # Check if parameter has a default value (makes it optional)
  2701. if param.default is not inspect.Parameter.empty:
  2702. optional = True
  2703. return formatted_type, optional
  2704. def _get_parameter_info(param_name, documented_params, source_args_dict, param_type, optional):
  2705. """
  2706. Get parameter documentation details from the appropriate source.
  2707. Tensor shape, optional status and description are taken from the custom docstring in priority if available.
  2708. Type is taken from the function signature first, then from the custom docstring if missing from the signature
  2709. Args:
  2710. param_name (`str`): Name of the parameter
  2711. documented_params (`dict`): Dictionary of documented parameters (manually specified in the docstring)
  2712. source_args_dict (`dict`): Default source args dictionary to use if not in documented_params
  2713. param_type (`str`): Current parameter type (may be updated)
  2714. optional (`bool`): Whether the parameter is optional (may be updated)
  2715. """
  2716. description = None
  2717. shape = None
  2718. shape_string = ""
  2719. is_documented = True
  2720. additional_info = None
  2721. optional_string = r", *optional*" if optional else ""
  2722. if param_name in documented_params:
  2723. # Parameter is documented in the function's docstring
  2724. if (
  2725. param_type == ""
  2726. and documented_params[param_name].get("type", None) is not None
  2727. or documented_params[param_name]["additional_info"]
  2728. ):
  2729. param_type = documented_params[param_name]["type"]
  2730. optional = documented_params[param_name]["optional"]
  2731. shape = documented_params[param_name].get("shape", None)
  2732. shape_string = shape if shape else ""
  2733. additional_info = documented_params[param_name]["additional_info"] or ""
  2734. description = f"{documented_params[param_name]['description']}\n"
  2735. elif param_name in source_args_dict:
  2736. # Parameter is documented in ModelArgs or ImageProcessorArgs
  2737. param_type = source_args_dict[param_name].get("type", param_type)
  2738. shape = source_args_dict[param_name].get("shape", None)
  2739. shape_string = " " + shape if shape else ""
  2740. description = source_args_dict[param_name]["description"]
  2741. additional_info = source_args_dict[param_name].get("additional_info", None)
  2742. if additional_info:
  2743. additional_info = shape_string + optional_string + ", " + additional_info
  2744. else:
  2745. # Parameter is not documented
  2746. is_documented = False
  2747. return param_type, optional_string, shape_string, additional_info, description, is_documented
  2748. def _process_regular_parameters(
  2749. sig,
  2750. func,
  2751. class_name,
  2752. documented_params,
  2753. indent_level,
  2754. undocumented_parameters,
  2755. source_args_dict,
  2756. parent_class,
  2757. allowed_params=None,
  2758. ):
  2759. """
  2760. Process all regular parameters (not kwargs parameters) from the function signature.
  2761. Args:
  2762. sig (`inspect.Signature`): Function signature
  2763. func (`function`): Function the parameters belong to
  2764. class_name (`str`): Name of the class
  2765. documented_params (`dict`): Dictionary of parameters that are already documented
  2766. indent_level (`int`): Indentation level
  2767. undocumented_parameters (`list`): List to append undocumented parameters to
  2768. """
  2769. docstring = ""
  2770. # Check if this is a processor by inspecting class hierarchy
  2771. is_processor = _is_processor_class(func, parent_class)
  2772. # Use appropriate args source based on whether it's a processor or not
  2773. if source_args_dict is None:
  2774. if is_processor:
  2775. source_args_dict = get_args_doc_from_source([ModelArgs, ImageProcessorArgs, ProcessorArgs])
  2776. else:
  2777. source_args_dict = get_args_doc_from_source([ModelArgs, ImageProcessorArgs])
  2778. missing_args = {}
  2779. for param_name, param in sig.parameters.items():
  2780. # Skip parameters that should be ignored
  2781. if (
  2782. param_name in ARGS_TO_IGNORE
  2783. or param_name.startswith("_") # Private/internal params (e.g. ClassVar-backed fields in configs)
  2784. or param.kind == inspect.Parameter.VAR_POSITIONAL
  2785. or param.kind == inspect.Parameter.VAR_KEYWORD
  2786. ):
  2787. continue
  2788. # When a filter is active (e.g. config classes: only own annotations), skip inherited params
  2789. if allowed_params is not None and param_name not in allowed_params:
  2790. continue
  2791. # When a filter is active (e.g. config classes: only own annotations), skip inherited params
  2792. if allowed_params is not None and param_name not in allowed_params:
  2793. continue
  2794. param_name = ARGS_TO_RENAME.get(param_name, param_name)
  2795. # Process parameter type and optional status
  2796. param_type, optional = _process_parameter_type(param)
  2797. # Check for default value
  2798. param_default = ""
  2799. if param.default != inspect._empty and param.default is not None:
  2800. param_default = f", defaults to `{str(param.default)}`"
  2801. param_type, optional_string, shape_string, additional_info, description, is_documented = _get_parameter_info(
  2802. param_name, documented_params, source_args_dict, param_type, optional
  2803. )
  2804. if is_documented:
  2805. if param_name == "config":
  2806. if param_type == "":
  2807. param_type = f"[`{class_name}`]"
  2808. else:
  2809. param_type = f"[`{param_type.split('.')[-1]}`]"
  2810. # elif param_type == "" and False: # TODO: Enforce typing for all parameters
  2811. # print(f"[ERROR] {param_name} for {func.__qualname__} in file {func.__code__.co_filename} has no type")
  2812. param_type = param_type if "`" in param_type else f"`{param_type}`"
  2813. # Format the parameter docstring
  2814. if additional_info:
  2815. param_docstring = f"{param_name} ({param_type}{additional_info}):{description}"
  2816. else:
  2817. param_docstring = (
  2818. f"{param_name} ({param_type}{shape_string}{optional_string}{param_default}):{description}"
  2819. )
  2820. docstring += set_min_indent(
  2821. param_docstring,
  2822. indent_level + 8,
  2823. )
  2824. else:
  2825. missing_args[param_name] = {
  2826. "type": param_type if param_type else "<fill_type>",
  2827. "optional": optional,
  2828. "shape": shape_string,
  2829. "description": description if description else "\n <fill_description>",
  2830. "default": param_default,
  2831. }
  2832. # Try to get the correct source file; for classes decorated with @strict (huggingface_hub),
  2833. # func.__code__.co_filename points to the wrapper in huggingface_hub, not the config file.
  2834. try:
  2835. if parent_class is not None:
  2836. _source_file = inspect.getsourcefile(parent_class) or func.__code__.co_filename
  2837. else:
  2838. _source_file = inspect.getsourcefile(inspect.unwrap(func)) or func.__code__.co_filename
  2839. except (TypeError, OSError):
  2840. _source_file = func.__code__.co_filename
  2841. undocumented_parameters.append(
  2842. f"[ERROR] `{param_name}` is part of {func.__qualname__}'s signature, but not documented. Make sure to add it to the docstring of the function in {_source_file}."
  2843. )
  2844. return docstring, missing_args
  2845. def find_sig_line(lines, line_end):
  2846. parenthesis_count = 0
  2847. sig_line_end = line_end
  2848. found_sig = False
  2849. while not found_sig:
  2850. for char in lines[sig_line_end]:
  2851. if char == "(":
  2852. parenthesis_count += 1
  2853. elif char == ")":
  2854. parenthesis_count -= 1
  2855. if parenthesis_count == 0:
  2856. found_sig = True
  2857. break
  2858. sig_line_end += 1
  2859. return sig_line_end
  2860. def _is_image_processor_class(func, parent_class):
  2861. """
  2862. Check if a function belongs to a ProcessorMixin class.
  2863. Uses two methods:
  2864. 1. Check parent_class inheritance (if provided)
  2865. 2. Check if the source file is named processing_*.py (multimodal processors)
  2866. vs image_processing_*.py, video_processing_*.py, etc. (single-modality processors)
  2867. Args:
  2868. func: The function to check
  2869. parent_class: Optional parent class (if available)
  2870. Returns:
  2871. bool: True if this is a multimodal processor (inherits from ProcessorMixin), False otherwise
  2872. """
  2873. # First, check if parent_class is provided and use it
  2874. if parent_class is not None:
  2875. return "BaseImageProcessor" in parent_class.__name__ or any(
  2876. "BaseImageProcessor" in base.__name__ for base in parent_class.__mro__
  2877. )
  2878. # If parent_class is None, check the filename
  2879. # Multimodal processors are in files named "processing_*.py"
  2880. # Single-modality processors are in "image_processing_*.py", "video_processing_*.py", etc.
  2881. try:
  2882. source_file = inspect.getsourcefile(func)
  2883. except TypeError:
  2884. return False
  2885. if not source_file:
  2886. return False
  2887. # Exception for DummyProcessorForTest
  2888. if func.__qualname__.split(".")[0] == "DummyForTestImageProcessorFast":
  2889. return True
  2890. filename = os.path.basename(source_file)
  2891. # Multimodal processors are implemented in processing_*.py modules
  2892. # (single-modality processors use image_processing_*, video_processing_*, etc.)self.
  2893. return filename.startswith("image_processing_") and filename.endswith(".py")
  2894. def _is_processor_class(func, parent_class):
  2895. """
  2896. Check if a function belongs to a ProcessorMixin class.
  2897. Uses two methods:
  2898. 1. Check parent_class inheritance (if provided)
  2899. 2. Check if the source file is named processing_*.py (multimodal processors)
  2900. vs image_processing_*.py, video_processing_*.py, etc. (single-modality processors)
  2901. Args:
  2902. func: The function to check
  2903. parent_class: Optional parent class (if available)
  2904. Returns:
  2905. bool: True if this is a multimodal processor (inherits from ProcessorMixin), False otherwise
  2906. """
  2907. # First, check if parent_class is provided and use it
  2908. if parent_class is not None:
  2909. return "ProcessorMixin" in parent_class.__name__ or any(
  2910. "ProcessorMixin" in base.__name__ for base in parent_class.__mro__
  2911. )
  2912. # If parent_class is None, check the filename
  2913. # Multimodal processors are in files named "processing_*.py"
  2914. # Single-modality processors are in "image_processing_*.py", "video_processing_*.py", etc.
  2915. try:
  2916. source_file = inspect.getsourcefile(func)
  2917. except TypeError:
  2918. return False
  2919. if not source_file:
  2920. return False
  2921. # Exception for DummyProcessorForTest
  2922. if func.__qualname__.split(".")[0] == "DummyProcessorForTest":
  2923. return True
  2924. filename = os.path.basename(source_file)
  2925. # Multimodal processors are implemented in processing_*.py modules
  2926. # (single-modality processors use image_processing_*, video_processing_*, etc.)self.
  2927. return filename.startswith("processing_") and filename.endswith(".py")
  2928. # Python < 3.12 fallback: naming heuristics when __orig_bases__ is not set (cpython#103699).
  2929. # Order matters: check ImageProcessorKwargs before ProcessorKwargs.
  2930. _BASIC_KWARGS_NAMES = frozenset({"ImagesKwargs", "ProcessingKwargs", "TextKwargs", "VideosKwargs", "AudioKwargs"})
  2931. _BASIC_KWARGS_CLASSES = None # Lazy-loaded name -> class mapping
  2932. def _get_base_kwargs_class_from_name(cls_name: str) -> str | None:
  2933. """Map kwargs class name to base using naming conventions. Returns base class name or None."""
  2934. if cls_name in _BASIC_KWARGS_NAMES:
  2935. return cls_name
  2936. if "ImageProcessorKwargs" in cls_name or cls_name.endswith("ImagesKwargs"):
  2937. return "ImagesKwargs"
  2938. if "ProcessorKwargs" in cls_name:
  2939. return "ProcessingKwargs"
  2940. if "VideoProcessorKwargs" in cls_name or cls_name.endswith("VideosKwargs"):
  2941. return "VideosKwargs"
  2942. if "AudioProcessorKwargs" in cls_name or cls_name.endswith("AudioKwargs"):
  2943. return "AudioKwargs"
  2944. if "TextKwargs" in cls_name:
  2945. return "TextKwargs"
  2946. return None
  2947. def _get_base_kwargs_class(cls):
  2948. """
  2949. Get the root/base TypedDict class by walking the inheritance chain.
  2950. For model-specific kwargs like ComplexProcessingKwargs(ProcessingKwargs), returns ProcessingKwargs.
  2951. For model-specific kwargs like DummyImageProcessorKwargs(ImagesKwargs), returns ImagesKwargs.
  2952. Compatibility: On Python < 3.12, non-generic TypedDict subclasses do not have __orig_bases__ set
  2953. (cpython#103699). We fall back to naming heuristics (e.g. *ImageProcessorKwargs -> ImagesKwargs).
  2954. """
  2955. current = cls
  2956. while True:
  2957. bases = typing_extensions.get_original_bases(current)
  2958. parent = None
  2959. for base in bases:
  2960. if isinstance(base, type) and base not in (dict, object):
  2961. if getattr(base, "__name__", "") == "TypedDict" and getattr(base, "__module__", "") == "typing":
  2962. continue
  2963. parent = base
  2964. break
  2965. if parent is None:
  2966. # Python < 3.12 fallback: use naming heuristics
  2967. base_name = _get_base_kwargs_class_from_name(current.__name__)
  2968. if base_name is not None:
  2969. global _BASIC_KWARGS_CLASSES
  2970. if _BASIC_KWARGS_CLASSES is None:
  2971. from transformers.processing_utils import (
  2972. AudioKwargs,
  2973. ImagesKwargs,
  2974. ProcessingKwargs,
  2975. TextKwargs,
  2976. VideosKwargs,
  2977. )
  2978. _BASIC_KWARGS_CLASSES = {
  2979. "ImagesKwargs": ImagesKwargs,
  2980. "ProcessingKwargs": ProcessingKwargs,
  2981. "TextKwargs": TextKwargs,
  2982. "VideosKwargs": VideosKwargs,
  2983. "AudioKwargs": AudioKwargs,
  2984. }
  2985. parent = _BASIC_KWARGS_CLASSES[base_name]
  2986. if parent is None or parent == current:
  2987. return current
  2988. current = parent
def _process_kwargs_parameters(sig, func, parent_class, documented_kwargs, indent_level, undocumented_parameters):
    """
    Process **kwargs parameters if needed.

    Typed `**kwargs` are only unrolled when `func.__name__` is in `UNROLL_KWARGS_METHODS`, or when the name of
    a class in `parent_class.__mro__` contains one of `UNROLL_KWARGS_CLASSES`. Otherwise two empty strings are
    returned.

    Args:
        sig (`inspect.Signature`): Function signature
        func (`function`): Function the parameters belong to
        parent_class (`class`): Parent class of the function
        documented_kwargs (`dict`): Dictionary of kwargs that are already documented. It is replaced by the
            parsed docstring of the kwargs annotation class when that class has a docstring.
        indent_level (`int`): Indentation level
        undocumented_parameters (`list`): List to append undocumented parameters to

    Returns:
        tuple[str, str]: (kwargs docstring, kwargs summary line to add after return_tensors)

    Raises:
        ValueError: If kwargs should be unrolled but the function is detected as neither a processor nor an
            image processor.
    """
    docstring = ""
    kwargs_summary = ""
    # Check if we need to add typed kwargs description to the docstring
    unroll_kwargs = func.__name__ in UNROLL_KWARGS_METHODS
    if not unroll_kwargs and parent_class is not None:
        # Check if the function has a parent class with unroll kwargs
        unroll_kwargs = any(
            any(unroll_kwargs_class in base.__name__ for base in parent_class.__mro__)
            for unroll_kwargs_class in UNROLL_KWARGS_CLASSES
        )
    if not unroll_kwargs:
        return docstring, kwargs_summary

    # Check if this is a processor by inspecting class hierarchy
    is_processor = _is_processor_class(func, parent_class)
    is_image_processor = _is_image_processor_class(func, parent_class)

    # Use appropriate args source based on whether it's a processor or not
    if is_processor:
        source_args_dict = get_args_doc_from_source([ImageProcessorArgs, ProcessorArgs])
    elif is_image_processor:
        source_args_dict = get_args_doc_from_source(ImageProcessorArgs)
    else:
        raise ValueError(
            f"Unrolling kwargs is not supported for {func.__name__} of {parent_class.__name__ if parent_class else 'None'} class"
        )

    # get all unpackable "kwargs" parameters
    kwargs_parameters = [
        kwargs_param
        for _, kwargs_param in sig.parameters.items()
        if kwargs_param.kind == inspect.Parameter.VAR_KEYWORD
    ]
    for kwarg_param in kwargs_parameters:
        # If kwargs not typed, skip
        if kwarg_param.annotation == inspect.Parameter.empty:
            continue

        # Skip annotations that have no type argument, or whose first type argument has no name
        # (presumably the annotation is `Unpack[SomeKwargs]` -- TODO confirm)
        if not hasattr(kwarg_param.annotation, "__args__") or not hasattr(
            kwarg_param.annotation.__args__[0], "__name__"
        ):
            continue

        # The fields of the basic kwargs types are not unrolled; only the summary line below is produced
        if kwarg_param.annotation.__args__[0].__name__ not in BASIC_KWARGS_TYPES:
            # Extract documentation for kwargs
            kwargs_documentation = kwarg_param.annotation.__args__[0].__doc__
            if kwargs_documentation is not None:
                documented_kwargs = parse_docstring(kwargs_documentation)[0]

            # Process each kwarg parameter
            for param_name, param_type_annotation in kwarg_param.annotation.__args__[0].__annotations__.items():
                # Handle nested kwargs structures for processors
                if is_processor and param_name.endswith("_kwargs"):
                    # Check if this is a basic kwargs type that should be skipped
                    # Basic kwargs types are generic containers that shouldn't be documented as individual params
                    # Get the actual type (unwrap Optional if needed)
                    actual_type = param_type_annotation
                    type_name = getattr(param_type_annotation, "__name__", None)
                    if type_name is None and hasattr(param_type_annotation, "__origin__"):
                        # Handle Optional[Type] or Union cases: keep the first argument that is not NoneType
                        args = getattr(param_type_annotation, "__args__", ())
                        for arg in args:
                            if arg is not type(None):
                                actual_type = arg
                                type_name = getattr(arg, "__name__", None)
                                break

                    # Skip only if it's one of the basic kwargs types
                    if type_name in BASIC_KWARGS_TYPES:
                        continue

                    # Otherwise, unroll the custom typed kwargs
                    # Get the nested TypedDict's annotations
                    if hasattr(actual_type, "__annotations__"):
                        nested_kwargs_doc = getattr(actual_type, "__doc__", None)
                        documented_nested_kwargs = {}
                        if nested_kwargs_doc:
                            documented_nested_kwargs = parse_docstring(nested_kwargs_doc)[0]

                        # Only process fields that are documented in the custom kwargs class's own docstring
                        # This prevents showing too many inherited parameters
                        if not documented_nested_kwargs:
                            # No documentation in the custom kwargs class, skip unrolling
                            continue

                        # Process each field in the custom typed kwargs
                        for nested_param_name, nested_param_type in actual_type.__annotations__.items():
                            # Only document parameters that are explicitly documented in the TypedDict's docstring
                            if nested_param_name not in documented_nested_kwargs:
                                continue
                            nested_param_type_str, nested_optional = process_type_annotation(
                                nested_param_type, nested_param_name
                            )

                            # Check for default value: a class attribute of the same name on parent_class, if any
                            nested_param_default = ""
                            if parent_class is not None:
                                nested_param_default = str(getattr(parent_class, nested_param_name, ""))
                                nested_param_default = (
                                    f", defaults to `{nested_param_default}`" if nested_param_default != "" else ""
                                )

                            # Only use the TypedDict's own docstring, not source_args_dict
                            # This prevents pulling in too many inherited parameters
                            (
                                nested_param_type_str,
                                nested_optional_string,
                                nested_shape_string,
                                nested_additional_info,
                                nested_description,
                                nested_is_documented,
                            ) = _get_parameter_info(
                                nested_param_name,
                                documented_nested_kwargs,
                                {},  # Empty dict - only use TypedDict's own docstring
                                nested_param_type_str,
                                nested_optional,
                            )

                            # nested_is_documented should always be True here since we filter for it above
                            # Check if type is missing
                            if nested_param_type_str == "":
                                print(
                                    f"🚨 {nested_param_name} for {type_name} in file {func.__code__.co_filename} has no type"
                                )
                            nested_param_type_str = (
                                nested_param_type_str if "`" in nested_param_type_str else f"`{nested_param_type_str}`"
                            )
                            # Format the parameter docstring (KWARGS_INDICATOR distinguishes from regular args)
                            if nested_additional_info:
                                docstring += set_min_indent(
                                    f"{nested_param_name} ({nested_param_type_str}{KWARGS_INDICATOR}{nested_additional_info}):{nested_description}",
                                    indent_level + 8,
                                )
                            else:
                                docstring += set_min_indent(
                                    f"{nested_param_name} ({nested_param_type_str}{KWARGS_INDICATOR}{nested_shape_string}{nested_optional_string}{nested_param_default}):{nested_description}",
                                    indent_level + 8,
                                )

                        # Skip processing the _kwargs parameter itself since we've processed its contents
                        continue
                    else:
                        # If we can't get annotations, skip this parameter
                        continue

                # When the kwargs class documents some fields, only those fields are listed
                if documented_kwargs and param_name not in documented_kwargs:
                    continue

                param_type, optional = process_type_annotation(param_type_annotation, param_name)

                # Check for default value: a class attribute of the same name on parent_class, if any
                param_default = ""
                if parent_class is not None:
                    param_default = str(getattr(parent_class, param_name, ""))
                    param_default = f", defaults to `{param_default}`" if param_default != "" else ""

                param_type, optional_string, shape_string, additional_info, description, is_documented = (
                    _get_parameter_info(param_name, documented_kwargs, source_args_dict, param_type, optional)
                )

                if is_documented:
                    # Check if type is missing
                    if param_type == "":
                        print(
                            f"[ERROR] {param_name} for {kwarg_param.annotation.__args__[0].__qualname__} in file {func.__code__.co_filename} has no type"
                        )
                    param_type = param_type if "`" in param_type else f"`{param_type}`"
                    # Format the parameter docstring (KWARGS_INDICATOR distinguishes from regular args)
                    if additional_info:
                        docstring += set_min_indent(
                            f"{param_name} ({param_type}{KWARGS_INDICATOR}{additional_info}):{description}",
                            indent_level + 8,
                        )
                    else:
                        docstring += set_min_indent(
                            f"{param_name} ({param_type}{KWARGS_INDICATOR}{shape_string}{optional_string}{param_default}):{description}",
                            indent_level + 8,
                        )
                else:
                    undocumented_parameters.append(
                        f"[ERROR] `{param_name}` is part of {kwarg_param.annotation.__args__[0].__qualname__}, but not documented. Make sure to add it to the docstring of the function in {func.__code__.co_filename}."
                    )

        # Build **kwargs summary line (added after return_tensors in _process_parameters_section)
        # The summary references the root kwargs class, not the model-specific subclass
        kwargs_annot_cls = kwarg_param.annotation.__args__[0]
        kwargs_type_name = _get_base_kwargs_class(kwargs_annot_cls).__name__
        kwargs_info = source_args_dict.get("__kwargs__", {})
        kwargs_description = kwargs_info.get(
            "description",
            "Additional keyword arguments. Model-specific parameters are listed above.",
        )
        kwargs_summary = set_min_indent(
            f"**kwargs ([`{kwargs_type_name}`], *optional*):{kwargs_description}",
            indent_level + 8,
        )

    return docstring, kwargs_summary
  3179. def _add_return_tensors_to_docstring(func, parent_class, docstring, indent_level):
  3180. """
  3181. Add return_tensors parameter documentation for processor __call__ methods if not already present.
  3182. Args:
  3183. func (`function`): Function being processed
  3184. parent_class (`class`): Parent class of the function
  3185. docstring (`str`): Current docstring being built
  3186. indent_level (`int`): Indentation level
  3187. Returns:
  3188. str: Updated docstring with return_tensors if applicable
  3189. """
  3190. # Check if this is a processor __call__ method or an image processor preprocess method
  3191. is_processor_call = False
  3192. is_image_processor_preprocess = False
  3193. if func.__name__ == "__call__":
  3194. # Check if this is a processor by inspecting class hierarchy
  3195. is_processor_call = _is_processor_class(func, parent_class)
  3196. if func.__name__ == "preprocess":
  3197. is_image_processor_preprocess = _is_image_processor_class(func, parent_class)
  3198. # If it's a processor __call__ method or an image processor preprocess method and return_tensors is not already documented
  3199. if (is_processor_call or is_image_processor_preprocess) and "return_tensors" not in docstring:
  3200. # Get the return_tensors documentation from ImageProcessorArgs
  3201. source_args_dict = (
  3202. get_args_doc_from_source(ProcessorArgs)
  3203. if is_processor_call
  3204. else get_args_doc_from_source(ImageProcessorArgs)
  3205. )
  3206. return_tensors_info = source_args_dict["return_tensors"]
  3207. param_type = return_tensors_info.get("type", "`str` or [`~utils.TensorType`]")
  3208. description = return_tensors_info["description"]
  3209. # Format the parameter type
  3210. param_type = param_type if "`" in param_type else f"`{param_type}`"
  3211. # Format the parameter docstring
  3212. param_docstring = f"return_tensors ({param_type}, *optional*):{description}"
  3213. docstring += set_min_indent(param_docstring, indent_level + 8)
  3214. return docstring
  3215. def _process_parameters_section(
  3216. func_documentation,
  3217. sig,
  3218. func,
  3219. class_name,
  3220. model_name_lowercase,
  3221. parent_class,
  3222. indent_level,
  3223. source_args_dict,
  3224. allowed_params,
  3225. ):
  3226. """
  3227. Process the parameters section of the docstring.
  3228. Args:
  3229. func_documentation (`str`): Existing function documentation (manually specified in the docstring)
  3230. sig (`inspect.Signature`): Function signature
  3231. func (`function`): Function the parameters belong to
  3232. class_name (`str`): Name of the class the function belongs to
  3233. model_name_lowercase (`str`): Lowercase model name
  3234. parent_class (`class`): Parent class of the function (if any)
  3235. indent_level (`int`): Indentation level
  3236. """
  3237. # Start Args section — constant string, min_indent is always 0, so skip set_min_indent
  3238. docstring = " " * (indent_level + 4) + "Args:\n"
  3239. undocumented_parameters = []
  3240. documented_params = {}
  3241. documented_kwargs = {}
  3242. # Parse existing docstring if available
  3243. if func_documentation is not None:
  3244. documented_params, func_documentation = parse_docstring(func_documentation)
  3245. # Process regular parameters
  3246. param_docstring, missing_args = _process_regular_parameters(
  3247. sig,
  3248. func,
  3249. class_name,
  3250. documented_params,
  3251. indent_level,
  3252. undocumented_parameters,
  3253. source_args_dict,
  3254. parent_class,
  3255. allowed_params,
  3256. )
  3257. docstring += param_docstring
  3258. # Process **kwargs parameters if needed
  3259. kwargs_docstring, kwargs_summary = _process_kwargs_parameters(
  3260. sig, func, parent_class, documented_kwargs, indent_level, undocumented_parameters
  3261. )
  3262. docstring += kwargs_docstring
  3263. # Add return_tensors for processor __call__ methods if not already present
  3264. docstring = _add_return_tensors_to_docstring(func, parent_class, docstring, indent_level)
  3265. # Add **kwargs summary line after return_tensors
  3266. docstring += kwargs_summary
  3267. # Report undocumented parameters
  3268. if len(undocumented_parameters) > 0:
  3269. print("\n".join(undocumented_parameters))
  3270. return docstring
  3271. def _prepare_return_docstring(output_type, config_class, add_intro=True):
  3272. """
  3273. Prepare the return docstring from a ModelOutput class.
  3274. This is a robust replacement for the old _prepare_output_docstrings from doc.py,
  3275. using the same parsing and formatting methods as the rest of auto_docstring.
  3276. Args:
  3277. output_type: The ModelOutput class to generate documentation for
  3278. config_class (`str`): Config class for the model
  3279. add_intro (`bool`): Whether to add the introduction text
  3280. Returns:
  3281. str: Formatted return docstring
  3282. """
  3283. output_docstring = output_type.__doc__
  3284. # If the class has no docstring, try to use the parent class's docstring
  3285. if output_docstring is None and hasattr(output_type, "__mro__"):
  3286. for base in output_type.__mro__[1:]: # Skip the class itself
  3287. if base.__doc__ is not None:
  3288. output_docstring = base.__doc__
  3289. break
  3290. if output_docstring is None:
  3291. if add_intro:
  3292. raise ValueError(
  3293. f"No docstring found for `{output_type.__name__}` or its parent classes. "
  3294. "Make sure the ModelOutput class or one of its parents has a docstring."
  3295. )
  3296. return ""
  3297. # Parse the output class docstring to extract parameters
  3298. documented_params, _ = parse_docstring(output_docstring)
  3299. if not documented_params and add_intro:
  3300. raise ValueError(
  3301. f"No `Args` or `Parameters` section is found in the docstring of `{output_type.__name__}`. "
  3302. "Make sure it has a docstring and contains either `Args` or `Parameters`."
  3303. )
  3304. # Build the return section
  3305. full_output_type, _ = process_type_annotation(output_type)
  3306. if add_intro:
  3307. # Import here to avoid circular import
  3308. from .doc import PT_RETURN_INTRODUCTION
  3309. intro = PT_RETURN_INTRODUCTION.format(full_output_type=full_output_type, config_class=config_class)
  3310. else:
  3311. intro = f"Returns:\n `{full_output_type}`"
  3312. if documented_params:
  3313. intro += ":\n"
  3314. else:
  3315. intro += "\n"
  3316. # Build the parameters section
  3317. params_text = ""
  3318. if documented_params:
  3319. for param_name, param_info in documented_params.items():
  3320. param_type = param_info.get("type", "")
  3321. param_description = param_info.get("description", "").strip()
  3322. additional_info = param_info.get("additional_info", "")
  3323. # Handle types with unbalanced backticks due to nested parentheses
  3324. # The parse_docstring function splits types like `tuple(torch.FloatTensor)` incorrectly
  3325. # so we need to reconstruct the complete type by grabbing the closing part from additional_info
  3326. if param_type.startswith("`") and not param_type.endswith("`"):
  3327. # Find the closing backtick in additional_info
  3328. closing_backtick_idx = additional_info.find("`")
  3329. if closing_backtick_idx != -1:
  3330. # Grab everything up to and including the closing backtick
  3331. param_type += additional_info[: closing_backtick_idx + 1]
  3332. # Remove that part from additional_info
  3333. additional_info = additional_info[closing_backtick_idx + 1 :]
  3334. # Strip backticks from type to add them back consistently
  3335. param_type = param_type.strip("`")
  3336. # Use process_type_annotation to ensure consistent type formatting
  3337. # This applies the same formatting rules as the rest of auto_docstring
  3338. if param_type:
  3339. param_type, _ = process_type_annotation(param_type)
  3340. # Build the parameter line
  3341. if additional_info:
  3342. # additional_info contains shape and optional status
  3343. param_line = f"- **{param_name}** (`{param_type}`{additional_info}) -- {param_description}"
  3344. else:
  3345. param_line = f"- **{param_name}** (`{param_type}`) -- {param_description}"
  3346. # Handle multi-line descriptions:
  3347. # Split the description to handle continuations with proper indentation
  3348. lines = param_line.split("\n")
  3349. formatted_lines = []
  3350. for i, line in enumerate(lines):
  3351. if i == 0:
  3352. # First line gets no extra indent (just the bullet point)
  3353. formatted_lines.append(line)
  3354. else:
  3355. # Continuation lines: strip existing indentation and add 2 spaces (relative to the bullet)
  3356. formatted_lines.append(" " + line.lstrip())
  3357. param_text = "\n".join(formatted_lines)
  3358. # Indent everything to 4 spaces and append with newline
  3359. param_text_indented = set_min_indent(param_text, 4)
  3360. params_text += param_text_indented + "\n"
  3361. result = intro + params_text
  3362. return result
  3363. def _process_returns_section(func_documentation, sig, config_class, indent_level):
  3364. """
  3365. Process the returns section of the docstring.
  3366. Args:
  3367. func_documentation (`str`): Existing function documentation (manually specified in the docstring)
  3368. sig (`inspect.Signature`): Function signature
  3369. config_class (`str`): Config class for the model
  3370. indent_level (`int`): Indentation level
  3371. """
  3372. return_docstring = ""
  3373. # Extract returns section from existing docstring if available
  3374. if func_documentation is not None and (match_start := _re_return.search(func_documentation)) is not None:
  3375. match_end = _re_example.search(func_documentation)
  3376. if match_end:
  3377. return_docstring = func_documentation[match_start.start() : match_end.start()]
  3378. func_documentation = func_documentation[match_end.start() :]
  3379. else:
  3380. return_docstring = func_documentation[match_start.start() :]
  3381. func_documentation = ""
  3382. return_docstring = set_min_indent(return_docstring, indent_level + 4)
  3383. # Otherwise, generate return docstring from return annotation if available
  3384. elif sig.return_annotation is not None and sig.return_annotation != inspect._empty:
  3385. add_intro, return_annotation = contains_type(sig.return_annotation, ModelOutput)
  3386. return_docstring = _prepare_return_docstring(return_annotation, config_class, add_intro=add_intro)
  3387. # PT_RETURN_INTRODUCTION already starts with \n, so only add blank line if it doesn't start with one
  3388. if not return_docstring.startswith("\n"):
  3389. return_docstring = "\n" + return_docstring
  3390. return_docstring = set_min_indent(return_docstring, indent_level + 4)
  3391. return return_docstring, func_documentation
  3392. def _process_example_section(
  3393. func_documentation, func, parent_class, class_name, model_name_lowercase, config_class, checkpoint, indent_level
  3394. ):
  3395. """
  3396. Process the example section of the docstring.
  3397. Args:
  3398. func_documentation (`str`): Existing function documentation (manually specified in the docstring)
  3399. func (`function`): Function being processed
  3400. parent_class (`class`): Parent class of the function
  3401. class_name (`str`): Name of the class
  3402. model_name_lowercase (`str`): Lowercase model name
  3403. config_class (`str`): Config class for the model
  3404. checkpoint: Checkpoint to use in examples
  3405. indent_level (`int`): Indentation level
  3406. """
  3407. # Import here to avoid circular import
  3408. from transformers.models import auto as auto_module
  3409. example_docstring = ""
  3410. # Use existing example section if available (with or without an "Example:" header)
  3411. if func_documentation is not None and (match := _re_example.search(func_documentation)):
  3412. example_docstring = func_documentation[match.start() :]
  3413. example_docstring = "\n" + set_min_indent(example_docstring, indent_level + 4)
  3414. # Skip examples for processors
  3415. elif _is_processor_class(func, parent_class):
  3416. # Processors don't get auto-generated examples
  3417. return example_docstring
  3418. # No examples for __init__ methods or if the class is not a model
  3419. elif parent_class is None and model_name_lowercase is not None:
  3420. global _re_model_task
  3421. if _re_model_task is None:
  3422. _re_model_task = re.compile(rf"({'|'.join(PT_SAMPLE_DOCSTRINGS.keys())})")
  3423. model_task = _re_model_task.search(class_name)
  3424. CONFIG_MAPPING = auto_module.configuration_auto.CONFIG_MAPPING
  3425. # Get checkpoint example
  3426. if (checkpoint_example := checkpoint) is None:
  3427. try:
  3428. checkpoint_example = get_checkpoint_from_config_class(CONFIG_MAPPING[model_name_lowercase])
  3429. except KeyError:
  3430. # For models with inconsistent lowercase model name
  3431. if model_name_lowercase in HARDCODED_CONFIG_FOR_MODELS:
  3432. CONFIG_MAPPING_NAMES = auto_module.configuration_auto.CONFIG_MAPPING_NAMES
  3433. config_class_name = HARDCODED_CONFIG_FOR_MODELS[model_name_lowercase]
  3434. if config_class_name in CONFIG_MAPPING_NAMES.values():
  3435. model_name_for_auto_config = [
  3436. k for k, v in CONFIG_MAPPING_NAMES.items() if v == config_class_name
  3437. ][0]
  3438. if model_name_for_auto_config in CONFIG_MAPPING:
  3439. checkpoint_example = get_checkpoint_from_config_class(
  3440. CONFIG_MAPPING[model_name_for_auto_config]
  3441. )
  3442. # Add example based on model task
  3443. if model_task is not None:
  3444. if checkpoint_example is not None:
  3445. example_annotation = ""
  3446. task = model_task.group()
  3447. example_annotation = PT_SAMPLE_DOCSTRINGS[task].format(
  3448. model_class=class_name,
  3449. checkpoint=checkpoint_example,
  3450. expected_output="...",
  3451. expected_loss="...",
  3452. qa_target_start_index=14,
  3453. qa_target_end_index=15,
  3454. mask="<mask>",
  3455. )
  3456. example_docstring = set_min_indent(example_annotation, indent_level + 4)
  3457. else:
  3458. print(
  3459. f"[ERROR] No checkpoint found for {class_name}.{func.__name__}. Please add a `checkpoint` arg to `auto_docstring` or add one in {config_class}'s docstring"
  3460. )
  3461. else:
  3462. # Check if the model is in a pipeline to get an example
  3463. for name_model_list_for_task in MODELS_TO_PIPELINE:
  3464. try:
  3465. model_list_for_task = getattr(auto_module.modeling_auto, name_model_list_for_task)
  3466. except (ImportError, AttributeError):
  3467. continue
  3468. if class_name in model_list_for_task.values():
  3469. pipeline_name = MODELS_TO_PIPELINE[name_model_list_for_task]
  3470. example_annotation = PIPELINE_TASKS_TO_SAMPLE_DOCSTRINGS[pipeline_name].format(
  3471. model_class=class_name,
  3472. checkpoint=checkpoint_example,
  3473. expected_output="...",
  3474. expected_loss="...",
  3475. qa_target_start_index=14,
  3476. qa_target_end_index=15,
  3477. )
  3478. example_docstring = set_min_indent(example_annotation, indent_level + 4)
  3479. break
  3480. return example_docstring
  3481. def auto_method_docstring(
  3482. func,
  3483. parent_class=None,
  3484. custom_intro=None,
  3485. custom_args=None,
  3486. checkpoint=None,
  3487. source_args_dict=None,
  3488. allowed_params=None,
  3489. ):
  3490. """
  3491. Wrapper that automatically generates docstring.
  3492. """
  3493. # Use inspect to retrieve the method's signature
  3494. sig = inspect.signature(func)
  3495. indent_level = get_indent_level(func) if not parent_class else get_indent_level(parent_class)
  3496. # Get model information
  3497. model_name_lowercase, class_name, config_class = _get_model_info(func, parent_class)
  3498. func_documentation = func.__doc__
  3499. if custom_args is not None and func_documentation is not None:
  3500. func_documentation = "\n" + set_min_indent(custom_args.strip("\n"), 0) + "\n" + func_documentation
  3501. elif custom_args is not None:
  3502. func_documentation = "\n" + set_min_indent(custom_args.strip("\n"), 0)
  3503. # Add intro to the docstring before args description if needed
  3504. if custom_intro is not None:
  3505. docstring = set_min_indent(custom_intro, indent_level + 4)
  3506. if not docstring.strip().endswith("\n"):
  3507. docstring += "\n"
  3508. else:
  3509. docstring = add_intro_docstring(func, class_name=class_name, indent_level=indent_level)
  3510. # Process Parameters section
  3511. docstring += _process_parameters_section(
  3512. func_documentation,
  3513. sig,
  3514. func,
  3515. class_name,
  3516. model_name_lowercase,
  3517. parent_class,
  3518. indent_level,
  3519. source_args_dict,
  3520. allowed_params,
  3521. )
  3522. # Process Returns section
  3523. return_docstring, func_documentation = _process_returns_section(
  3524. func_documentation, sig, config_class, indent_level
  3525. )
  3526. docstring += return_docstring
  3527. # Process Example section
  3528. example_docstring = _process_example_section(
  3529. func_documentation,
  3530. func,
  3531. parent_class,
  3532. class_name,
  3533. model_name_lowercase,
  3534. config_class,
  3535. checkpoint,
  3536. indent_level,
  3537. )
  3538. docstring += example_docstring
  3539. # Format the docstring with the placeholders
  3540. docstring = format_args_docstring(docstring, model_name_lowercase)
  3541. # Assign the dynamically generated docstring to the wrapper function
  3542. func.__doc__ = docstring
  3543. return func
  3544. def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=None):
  3545. """
  3546. Wrapper that automatically generates a docstring for classes based on their attributes and methods.
  3547. """
  3548. # import here to avoid circular import
  3549. from transformers.models import auto as auto_module
  3550. is_dataclass = False
  3551. is_processor = False
  3552. is_config = False
  3553. is_image_processor = False
  3554. docstring_init = ""
  3555. docstring_args = ""
  3556. if "PreTrainedModel" in (x.__name__ for x in cls.__mro__):
  3557. docstring_init = auto_method_docstring(
  3558. cls.__init__, parent_class=cls, custom_args=custom_args, checkpoint=checkpoint
  3559. ).__doc__.replace("Args:", "Parameters:")
  3560. elif "ProcessorMixin" in (x.__name__ for x in cls.__mro__):
  3561. is_processor = True
  3562. docstring_init = auto_method_docstring(
  3563. cls.__init__,
  3564. parent_class=cls,
  3565. custom_args=custom_args,
  3566. checkpoint=checkpoint,
  3567. source_args_dict=get_args_doc_from_source([ModelArgs, ImageProcessorArgs, ProcessorArgs]),
  3568. ).__doc__.replace("Args:", "Parameters:")
  3569. elif "ModelOutput" in (x.__name__ for x in cls.__mro__):
  3570. # We have a data class
  3571. is_dataclass = True
  3572. doc_class = cls.__doc__
  3573. if custom_args is None and doc_class:
  3574. custom_args = doc_class
  3575. docstring_args = auto_method_docstring(
  3576. cls.__init__,
  3577. parent_class=cls,
  3578. custom_args=custom_args,
  3579. checkpoint=checkpoint,
  3580. source_args_dict=get_args_doc_from_source(ModelOutputArgs),
  3581. ).__doc__
  3582. elif any("BaseImageProcessor" in x.__name__ for x in cls.__mro__):
  3583. is_image_processor = True
  3584. docstring_init = auto_method_docstring(
  3585. cls.__init__,
  3586. parent_class=cls,
  3587. custom_args=custom_args,
  3588. checkpoint=checkpoint,
  3589. source_args_dict=get_args_doc_from_source(ImageProcessorArgs),
  3590. ).__doc__
  3591. elif "PreTrainedConfig" in (x.__name__ for x in cls.__mro__):
  3592. is_config = True
  3593. doc_class = cls.__doc__
  3594. if custom_args is None and doc_class:
  3595. custom_args = doc_class
  3596. # Collect all non-ClassVar annotations from the class and its ancestors up to
  3597. # (but not including) PreTrainedConfig. This allows inherited params from intermediate
  3598. # config base classes to be documented, while naturally excluding PreTrainedConfig-specific
  3599. # quasi-ClassVar params (e.g. `transformers_version`, `architectures`).
  3600. own_config_params = set()
  3601. for ancestor in cls.__mro__:
  3602. if ancestor.__name__ == "PreTrainedConfig":
  3603. break
  3604. own_config_params |= {
  3605. k for k, v in getattr(ancestor, "__annotations__", {}).items() if get_origin(v) is not ClassVar
  3606. }
  3607. allowed_params = own_config_params if own_config_params else None
  3608. docstring_init = auto_method_docstring(
  3609. cls.__init__,
  3610. parent_class=cls,
  3611. custom_args=custom_args,
  3612. checkpoint=checkpoint,
  3613. source_args_dict=get_args_doc_from_source([ConfigArgs]),
  3614. allowed_params=allowed_params,
  3615. ).__doc__
  3616. indent_level = get_indent_level(cls)
  3617. model_name_lowercase = get_model_name(cls)
  3618. model_name_title = " ".join([k.title() for k in model_name_lowercase.split("_")]) if model_name_lowercase else None
  3619. model_base_class = f"{model_name_title.title()}Model" if model_name_title is not None else None
  3620. if model_name_lowercase is not None:
  3621. try:
  3622. model_base_class = getattr(
  3623. getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE["model_class"][0]),
  3624. PLACEHOLDER_TO_AUTO_MODULE["model_class"][1],
  3625. )[model_name_lowercase]
  3626. except KeyError:
  3627. pass
  3628. except ImportError:
  3629. # In some environments, certain model classes might not be available. In that case, we can skip this part.
  3630. pass
  3631. if model_name_lowercase and model_name_lowercase not in getattr(
  3632. getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE["config_class"][0]),
  3633. PLACEHOLDER_TO_AUTO_MODULE["config_class"][1],
  3634. ):
  3635. model_name_lowercase = model_name_lowercase.replace("_", "-")
  3636. name = re.findall(rf"({'|'.join(ClassDocstring.__dict__.keys())})$", cls.__name__)
  3637. if name == [] and custom_intro is None and not is_dataclass and not is_processor and not is_image_processor:
  3638. raise ValueError(
  3639. f"`{cls.__name__}` is not registered in the auto doc. Here are the available classes: {ClassDocstring.__dict__.keys()}.\n"
  3640. "Add a `custom_intro` to the decorator if you want to use `auto_docstring` on a class not registered in the auto doc."
  3641. )
  3642. if name != [] or custom_intro is not None or is_config or is_dataclass or is_processor or is_image_processor:
  3643. name = name[0] if name else None
  3644. formatting_kwargs = {"model_name": model_name_title}
  3645. if name == "Config":
  3646. formatting_kwargs.update({"model_base_class": model_base_class, "model_checkpoint": checkpoint})
  3647. if custom_intro is not None:
  3648. pre_block = equalize_indent(custom_intro, indent_level)
  3649. if not pre_block.endswith("\n"):
  3650. pre_block += "\n"
  3651. elif is_processor:
  3652. # Generate processor intro dynamically
  3653. pre_block = generate_processor_intro(cls)
  3654. if pre_block:
  3655. pre_block = equalize_indent(pre_block, indent_level)
  3656. pre_block = format_args_docstring(pre_block, model_name_lowercase)
  3657. elif is_image_processor:
  3658. pre_block = r"Constructs a {image_processor_class} image processor."
  3659. if pre_block:
  3660. pre_block = equalize_indent(pre_block, indent_level)
  3661. pre_block = format_args_docstring(pre_block, model_name_lowercase)
  3662. elif model_name_title is None or name is None:
  3663. pre_block = ""
  3664. else:
  3665. pre_block = getattr(ClassDocstring, name).format(**formatting_kwargs)
  3666. # Start building the docstring
  3667. docstring = set_min_indent(f"{pre_block}", indent_level) if len(pre_block) else ""
  3668. if name != "PreTrainedModel" and "PreTrainedModel" in (x.__name__ for x in cls.__mro__):
  3669. docstring += set_min_indent(f"{ClassDocstring.PreTrainedModel}", indent_level)
  3670. # Add the __init__ docstring
  3671. if docstring_init:
  3672. docstring += set_min_indent(f"\n{docstring_init}", indent_level)
  3673. elif is_dataclass or is_config:
  3674. # No init function, we have a data class
  3675. docstring += docstring_args if docstring_args else "\nArgs:\n"
  3676. source_args_dict = get_args_doc_from_source(ModelOutputArgs)
  3677. doc_class = cls.__doc__ if cls.__doc__ else ""
  3678. documented_kwargs = parse_docstring(doc_class)[0]
  3679. for param_name, param_type_annotation in cls.__annotations__.items():
  3680. param_type, optional = process_type_annotation(param_type_annotation, param_name)
  3681. # Check for default value
  3682. param_default = ""
  3683. param_default = str(getattr(cls, param_name, ""))
  3684. param_default = f", defaults to `{param_default}`" if param_default != "" else ""
  3685. param_type, optional_string, shape_string, additional_info, description, is_documented = (
  3686. _get_parameter_info(param_name, documented_kwargs, source_args_dict, param_type, optional)
  3687. )
  3688. if is_documented:
  3689. # Check if type is missing
  3690. if param_type == "":
  3691. print(
  3692. f"[ERROR] {param_name} for {cls.__qualname__} in file {cls.__code__.co_filename} has no type"
  3693. )
  3694. param_type = param_type if "`" in param_type else f"`{param_type}`"
  3695. # Format the parameter docstring
  3696. if additional_info:
  3697. docstring += set_min_indent(
  3698. f"{param_name} ({param_type}{additional_info}):{description}",
  3699. indent_level + 8,
  3700. )
  3701. else:
  3702. docstring += set_min_indent(
  3703. f"{param_name} ({param_type}{shape_string}{optional_string}{param_default}):{description}",
  3704. indent_level + 8,
  3705. )
  3706. # TODO (Yoni): Add support for Attributes section in docs
  3707. else:
  3708. print(
  3709. f"You used `@auto_class_docstring` decorator on `{cls.__name__}` but this class is not part of the AutoMappings. Remove the decorator"
  3710. )
  3711. # Assign the dynamically generated docstring to the wrapper class
  3712. cls.__doc__ = docstring
  3713. return cls
  3714. def auto_docstring(obj=None, *, custom_intro=None, custom_args=None, checkpoint=None):
  3715. r"""
  3716. Automatically generates comprehensive docstrings for model classes and methods in the Transformers library.
  3717. This decorator reduces boilerplate by automatically including standard argument descriptions while allowing
  3718. overrides to add new or custom arguments. It inspects function signatures, retrieves predefined docstrings
  3719. for common arguments (like `input_ids`, `attention_mask`, etc.), and generates complete documentation
  3720. including examples and return value descriptions.
  3721. For complete documentation and examples, read this [guide](https://huggingface.co/docs/transformers/auto_docstring).
  3722. Examples of usage:
  3723. Basic usage (no parameters):
  3724. ```python
  3725. @auto_docstring
  3726. class MyAwesomeModel(PreTrainedModel):
  3727. def __init__(self, config, custom_parameter: int = 10):
  3728. r'''
  3729. custom_parameter (`int`, *optional*, defaults to 10):
  3730. Description of the custom parameter for MyAwesomeModel.
  3731. '''
  3732. super().__init__(config)
  3733. self.custom_parameter = custom_parameter
  3734. ```
  3735. Using `custom_intro` with a class:
  3736. ```python
  3737. @auto_docstring(
  3738. custom_intro="This model implements a novel attention mechanism for improved performance."
  3739. )
  3740. class MySpecialModel(PreTrainedModel):
  3741. def __init__(self, config, attention_type: str = "standard"):
  3742. r'''
  3743. attention_type (`str`, *optional*, defaults to "standard"):
  3744. Type of attention mechanism to use.
  3745. '''
  3746. super().__init__(config)
  3747. ```
  3748. Using `custom_intro` with a method, and specify custom arguments and example directly in the docstring:
  3749. ```python
  3750. @auto_docstring(
  3751. custom_intro="Performs forward pass with enhanced attention computation."
  3752. )
  3753. def forward(
  3754. self,
  3755. input_ids: Optional[torch.Tensor] = None,
  3756. attention_mask: Optional[torch.Tensor] = None,
  3757. ):
  3758. r'''
  3759. custom_parameter (`int`, *optional*, defaults to 10):
  3760. Description of the custom parameter for MyAwesomeModel.
  3761. Example:
  3762. ```python
  3763. >>> model = MyAwesomeModel(config)
  3764. >>> model.forward(input_ids=torch.tensor([1, 2, 3]), attention_mask=torch.tensor([1, 1, 1]))
  3765. ```
  3766. '''
  3767. ```
  3768. Using `custom_args` to define reusable arguments:
  3769. ```python
  3770. VISION_ARGS = r'''
  3771. pixel_values (`torch.FloatTensor`, *optional*):
  3772. Pixel values of the input images.
  3773. image_features (`torch.FloatTensor`, *optional*):
  3774. Pre-computed image features for efficient processing.
  3775. '''
  3776. @auto_docstring(custom_args=VISION_ARGS)
  3777. def encode_images(self, pixel_values=None, image_features=None):
  3778. # ... method implementation
  3779. ```
  3780. Combining `custom_intro` and `custom_args`:
  3781. ```python
  3782. MULTIMODAL_ARGS = r'''
  3783. vision_features (`torch.FloatTensor`, *optional*):
  3784. Pre-extracted vision features from the vision encoder.
  3785. fusion_strategy (`str`, *optional*, defaults to "concat"):
  3786. Strategy for fusing text and vision modalities.
  3787. '''
  3788. @auto_docstring(
  3789. custom_intro="Processes multimodal inputs combining text and vision.",
  3790. custom_args=MULTIMODAL_ARGS
  3791. )
  3792. def forward(
  3793. self,
  3794. input_ids,
  3795. attention_mask=None,
  3796. vision_features=None,
  3797. fusion_strategy="concat"
  3798. ):
  3799. # ... multimodal processing
  3800. ```
  3801. Using with ModelOutput classes:
  3802. ```python
  3803. @dataclass
  3804. @auto_docstring(
  3805. custom_intro="Custom model outputs with additional fields."
  3806. )
  3807. class MyModelOutput(ImageClassifierOutput):
  3808. r'''
  3809. loss (`torch.FloatTensor`, *optional*):
  3810. The loss of the model.
  3811. custom_field (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
  3812. A custom output field specific to this model.
  3813. '''
  3814. # Standard fields like hidden_states, logits, attentions etc. can be automatically documented
  3815. # However, given that the loss docstring is often different per model, you should document it above
  3816. loss: Optional[torch.FloatTensor] = None
  3817. logits: Optional[torch.FloatTensor] = None
  3818. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  3819. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  3820. custom_field: Optional[torch.FloatTensor] = None
  3821. ```
  3822. Args:
  3823. custom_intro (`str`, *optional*):
  3824. Custom introduction text to add to the docstring. This replaces the default
  3825. introduction text generated by the decorator before the Args section. Use this to describe what
  3826. makes your model or method special.
  3827. custom_args (`str`, *optional*):
  3828. Custom argument documentation in docstring format. This allows you to define
  3829. argument descriptions once and reuse them across multiple methods. The format should follow the
  3830. standard docstring convention: `arg_name (`type`, *optional*, defaults to `value`): Description.`
  3831. checkpoint (`str`, *optional*):
  3832. Checkpoint name to use in examples within the docstring. This is typically
  3833. automatically inferred from the model configuration class, but can be overridden if needed for
  3834. custom examples.
  3835. Note:
  3836. - Standard arguments (`input_ids`, `attention_mask`, `pixel_values`, etc.) are automatically documented
  3837. from predefined descriptions and should not be redefined unless their behavior differs in your model.
  3838. - New or custom arguments should be documented in the method's docstring using the `r''' '''` block
  3839. or passed via the `custom_args` parameter.
  3840. - For model classes, the decorator derives parameter descriptions from the `__init__` method's signature
  3841. and docstring.
  3842. - Return value documentation is automatically generated for methods that return ModelOutput subclasses.
  3843. """
  3844. def auto_docstring_decorator(obj):
  3845. if len(obj.__qualname__.split(".")) > 1:
  3846. return auto_method_docstring(
  3847. obj, custom_args=custom_args, custom_intro=custom_intro, checkpoint=checkpoint
  3848. )
  3849. else:
  3850. return auto_class_docstring(obj, custom_args=custom_args, custom_intro=custom_intro, checkpoint=checkpoint)
  3851. if obj:
  3852. return auto_docstring_decorator(obj)
  3853. return auto_docstring_decorator