testing_utils.py 156 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import ast
  15. import collections
  16. import contextlib
  17. import copy
  18. import doctest
  19. import functools
  20. import gc
  21. import importlib
  22. import inspect
  23. import json
  24. import logging
  25. import multiprocessing
  26. import os
  27. import re
  28. import shlex
  29. import shutil
  30. import subprocess
  31. import sys
  32. import tempfile
  33. import threading
  34. import time
  35. import traceback
  36. import types
  37. import unittest
  38. from collections import UserDict, defaultdict
  39. from collections.abc import Callable, Generator, Iterable, Iterator, Mapping
  40. from contextlib import contextmanager
  41. from dataclasses import MISSING, fields
  42. from functools import cache, wraps
  43. from io import StringIO
  44. from pathlib import Path
  45. from typing import TYPE_CHECKING, Any
  46. from unittest import mock
  47. from unittest.mock import patch
  48. import httpx
  49. from huggingface_hub import create_repo, delete_repo
  50. from packaging import version
  51. from transformers import logging as transformers_logging
  52. if TYPE_CHECKING:
  53. from .trainer import Trainer
  54. else:
  55. Trainer = Any # type: ignore
  56. from .integrations import (
  57. is_clearml_available,
  58. is_optuna_available,
  59. is_ray_available,
  60. is_swanlab_available,
  61. is_tensorboard_available,
  62. is_trackio_available,
  63. is_wandb_available,
  64. )
  65. from .integrations.deepspeed import is_deepspeed_available
  66. from .utils import (
  67. ACCELERATE_MIN_VERSION,
  68. GGUF_MIN_VERSION,
  69. SAFE_WEIGHTS_INDEX_NAME,
  70. TRITON_MIN_VERSION,
  71. WEIGHTS_INDEX_NAME,
  72. is_accelerate_available,
  73. is_apex_available,
  74. is_apollo_torch_available,
  75. is_aqlm_available,
  76. is_auto_round_available,
  77. is_av_available,
  78. is_bitsandbytes_available,
  79. is_bs4_available,
  80. is_compressed_tensors_available,
  81. is_cv2_available,
  82. is_cython_available,
  83. is_decord_available,
  84. is_detectron2_available,
  85. is_essentia_available,
  86. is_faiss_available,
  87. is_fbgemm_gpu_available,
  88. is_flash_attn_2_available,
  89. is_flash_attn_3_available,
  90. is_flash_attn_4_available,
  91. is_flute_available,
  92. is_fouroversix_available,
  93. is_fp_quant_available,
  94. is_fsdp_available,
  95. is_g2p_en_available,
  96. is_galore_torch_available,
  97. is_gguf_available,
  98. is_gptqmodel_available,
  99. is_grokadamw_available,
  100. is_hadamard_available,
  101. is_hqq_available,
  102. is_huggingface_hub_greater_or_equal,
  103. is_jinja_available,
  104. is_jmespath_available,
  105. is_jumanpp_available,
  106. is_kernels_available,
  107. is_levenshtein_available,
  108. is_librosa_available,
  109. is_liger_kernel_available,
  110. is_lomo_available,
  111. is_mistral_common_available,
  112. is_multipart_available,
  113. is_natten_available,
  114. is_nltk_available,
  115. is_numba_available,
  116. is_onnx_available,
  117. is_openai_available,
  118. is_optimum_available,
  119. is_optimum_quanto_available,
  120. is_pandas_available,
  121. is_peft_available,
  122. is_phonemizer_available,
  123. is_pretty_midi_available,
  124. is_psutil_available,
  125. is_pyctcdecode_available,
  126. is_pytesseract_available,
  127. is_pytest_available,
  128. is_pytest_order_available,
  129. is_pytorch_quantization_available,
  130. is_quark_available,
  131. is_qutlass_available,
  132. is_rjieba_available,
  133. is_sacremoses_available,
  134. is_schedulefree_available,
  135. is_scipy_available,
  136. is_sentencepiece_available,
  137. is_seqio_available,
  138. is_serve_available,
  139. is_spacy_available,
  140. is_speech_available,
  141. is_spqr_available,
  142. is_sudachi_available,
  143. is_sudachi_projection_available,
  144. is_tiktoken_available,
  145. is_timm_available,
  146. is_tokenizers_available,
  147. is_torch_available,
  148. is_torch_bf16_available_on_device,
  149. is_torch_fp16_available_on_device,
  150. is_torch_greater_or_equal,
  151. is_torch_hpu_available,
  152. is_torch_mlu_available,
  153. is_torch_neuroncore_available,
  154. is_torch_npu_available,
  155. is_torch_optimi_available,
  156. is_torch_tensorrt_fx_available,
  157. is_torch_tf32_available,
  158. is_torch_xla_available,
  159. is_torch_xpu_available,
  160. is_torchao_available,
  161. is_torchaudio_available,
  162. is_torchcodec_available,
  163. is_torchvision_available,
  164. is_triton_available,
  165. is_vision_available,
  166. is_vptq_available,
  167. strtobool,
  168. )
  169. if is_accelerate_available():
  170. from accelerate.state import AcceleratorState, PartialState
  171. from accelerate.utils.imports import is_fp8_available
  172. if is_pytest_available():
  173. from _pytest.doctest import (
  174. Module,
  175. _get_checker,
  176. _get_continue_on_failure,
  177. _get_runner,
  178. _is_mocked,
  179. _patch_unwrap_mock_aware,
  180. get_optionflags,
  181. )
  182. from _pytest.outcomes import skip
  183. from _pytest.pathlib import import_path
  184. from pytest import DoctestItem
  185. else:
  186. Module = object
  187. DoctestItem = object
# Hub repository ids referenced by tests (their names indicate small / dummy checkpoints).
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
# NOTE(review): the comment above is not followed by any definition of its own; presumably it describes one of the
# identifiers just before it (e.g. DUMMY_UNKNOWN_IDENTIFIER) — confirm and move it next to that constant.

# Used to test the hub
USER = "__DUMMY_TRANSFORMERS_USER__"
ENDPOINT_STAGING = "https://hub-ci.huggingface.co"
# Not critical, only usable on the sandboxed CI instance.
# NOTE(review): this is a hard-coded access token committed to source; per the comment above it is only valid on the
# sandboxed CI Hub (presumably ENDPOINT_STAGING). Never reuse it against the production Hub.
TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"

# Used in CausalLMModelTester (and related classes/methods) to infer the common model classes from the base model class.
# Keys are tester attribute names; values look like class-name suffixes (e.g. "ForCausalLM") — presumably combined
# with the model's name prefix by the tester; confirm there.
_COMMON_MODEL_NAMES_MAP = {
    "config_class": "Config",
    "causal_lm_class": "ForCausalLM",
    "question_answering_class": "ForQuestionAnswering",
    "sequence_classification_class": "ForSequenceClassification",
    "token_classification_class": "ForTokenClassification",
}

# Used in VLMModelTester (and related classes/methods) to infer the common model classes from the base model class.
# Same key/value convention as `_COMMON_MODEL_NAMES_MAP`, for vision-language models.
_VLM_COMMON_MODEL_NAMES_MAP = {
    "config_class": "Config",
    "text_config_class": "TextConfig",
    "vision_config_class": "VisionConfig",
    "conditional_generation_class": "ForConditionalGeneration",
}
# Optional torch-dependent imports plus detection of which accelerator build of torch is installed.
# When torch is not available, every IS_*_SYSTEM flag is simply False.
if is_torch_available():
    import torch
    from safetensors.torch import load_file

    from .modeling_utils import FLASH_ATTN_KERNEL_FALLBACK, PreTrainedModel

    IS_ROCM_SYSTEM = torch.version.hip is not None
    IS_CUDA_SYSTEM = torch.version.cuda is not None
    # `getattr(..., None)` so builds lacking the attribute entirely evaluate to False instead of raising.
    IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None
    # NOTE(review): `torch.npu` presumably only exists once the Ascend `torch_npu` extension is loaded — confirm.
    IS_NPU_SYSTEM = getattr(torch, "npu", None) is not None
else:
    IS_ROCM_SYSTEM = False
    IS_CUDA_SYSTEM = False
    IS_XPU_SYSTEM = False
    IS_NPU_SYSTEM = False

logger = transformers_logging.get_logger(__name__)
  226. def parse_flag_from_env(key, default=False):
  227. try:
  228. value = os.environ[key]
  229. except KeyError:
  230. # KEY isn't set, default to `default`.
  231. _value = default
  232. else:
  233. # KEY is set, convert it to True or False.
  234. try:
  235. _value = strtobool(value)
  236. except ValueError:
  237. # More values are supported, but let's keep the message simple.
  238. raise ValueError(f"If set, {key} must be yes or no.")
  239. return _value
  240. def parse_int_from_env(key, default=None):
  241. try:
  242. value = os.environ[key]
  243. except KeyError:
  244. _value = default
  245. else:
  246. try:
  247. _value = int(value)
  248. except ValueError:
  249. raise ValueError(f"If set, {key} must be a int.")
  250. return _value
# Test-suite switches, evaluated once when this module is imported. Each reads an environment variable through
# `parse_flag_from_env`; `default=False` means the suite is opt-in, `default=True` means it runs unless disabled.
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)  # opt-in
_run_flaky_tests = parse_flag_from_env("RUN_FLAKY", default=True)  # opt-out
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)  # opt-in
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)  # opt-in
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)  # opt-out
_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)  # opt-in
_run_training_tests = parse_flag_from_env("RUN_TRAINING_TESTS", default=True)  # opt-out
_run_tensor_parallel_tests = parse_flag_from_env("RUN_TENSOR_PARALLEL_TESTS", default=True)  # opt-out
  259. def is_staging_test(test_case):
  260. """
  261. Decorator marking a test as a staging test.
  262. Those tests will run using the staging environment of huggingface.co instead of the real model hub.
  263. """
  264. if not _run_staging:
  265. return unittest.skip(reason="test is staging test")(test_case)
  266. else:
  267. try:
  268. import pytest # We don't need a hard dependency on pytest in the main library
  269. except ImportError:
  270. return test_case
  271. else:
  272. return pytest.mark.is_staging_test()(test_case)
  273. def is_pipeline_test(test_case):
  274. """
  275. Decorator marking a test as a pipeline test. If RUN_PIPELINE_TESTS is set to a falsy value, those tests will be
  276. skipped.
  277. """
  278. if not _run_pipeline_tests:
  279. return unittest.skip(reason="test is pipeline test")(test_case)
  280. else:
  281. try:
  282. import pytest # We don't need a hard dependency on pytest in the main library
  283. except ImportError:
  284. return test_case
  285. else:
  286. return pytest.mark.is_pipeline_test()(test_case)
  287. def is_agent_test(test_case):
  288. """
  289. Decorator marking a test as an agent test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped.
  290. """
  291. if not _run_agent_tests:
  292. return unittest.skip(reason="test is an agent test")(test_case)
  293. else:
  294. try:
  295. import pytest # We don't need a hard dependency on pytest in the main library
  296. except ImportError:
  297. return test_case
  298. else:
  299. return pytest.mark.is_agent_test()(test_case)
  300. def is_training_test(test_case):
  301. """
  302. Decorator marking a test as a training test. If RUN_TRAINING_TESTS is set to a falsy value, those tests will be
  303. skipped.
  304. """
  305. if not _run_training_tests:
  306. return unittest.skip(reason="test is training test")(test_case)
  307. else:
  308. try:
  309. import pytest # We don't need a hard dependency on pytest in the main library
  310. except ImportError:
  311. return test_case
  312. else:
  313. return pytest.mark.is_training_test()(test_case)
  314. def is_tensor_parallel_test(test_case):
  315. """
  316. Decorator marking a test as a tensor parallel test. If RUN_TENSOR_PARALLEL_TESTS is set to a falsy value, those
  317. tests will be skipped.
  318. """
  319. if not _run_tensor_parallel_tests:
  320. return unittest.skip(reason="test is tensor parallel test")(test_case)
  321. else:
  322. try:
  323. import pytest # We don't need a hard dependency on pytest in the main library
  324. except ImportError:
  325. return test_case
  326. else:
  327. return pytest.mark.is_tensor_parallel_test()(test_case)
  328. def slow(test_case):
  329. """
  330. Decorator marking a test as slow.
  331. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
  332. """
  333. return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
  334. def tooslow(test_case):
  335. """
  336. Decorator marking a test as too slow.
  337. Slow tests are skipped while they're in the process of being fixed. No test should stay tagged as "tooslow" as
  338. these will not be tested by the CI.
  339. """
  340. return unittest.skip(reason="test is too slow")(test_case)
  341. def skip_if_not_implemented(test_func):
  342. @functools.wraps(test_func)
  343. def wrapper(*args, **kwargs):
  344. try:
  345. return test_func(*args, **kwargs)
  346. except NotImplementedError as e:
  347. raise unittest.SkipTest(f"Test skipped due to NotImplementedError: {e}")
  348. return wrapper
  349. def apply_skip_if_not_implemented(cls):
  350. """
  351. Class decorator to apply @skip_if_not_implemented to all test methods.
  352. """
  353. for attr_name in dir(cls):
  354. if attr_name.startswith("test_"):
  355. attr = getattr(cls, attr_name)
  356. if callable(attr):
  357. setattr(cls, attr_name, skip_if_not_implemented(attr))
  358. return cls
  359. def custom_tokenizers(test_case):
  360. """
  361. Decorator marking a test for a custom tokenizer.
  362. Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS
  363. environment variable to a truthy value to run them.
  364. """
  365. return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
  366. def require_bs4(test_case):
  367. """
  368. Decorator marking a test that requires BeautifulSoup4. These tests are skipped when BeautifulSoup4 isn't installed.
  369. """
  370. return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)
  371. def require_galore_torch(test_case):
  372. """
  373. Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed.
  374. https://github.com/jiaweizzhao/GaLore
  375. """
  376. return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
  377. def require_apollo_torch(test_case):
  378. """
  379. Decorator marking a test that requires GaLore. These tests are skipped when APOLLO isn't installed.
  380. https://github.com/zhuhanqing/APOLLO
  381. """
  382. return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case)
  383. def require_torch_optimi(test_case):
  384. """
  385. Decorator marking a test that requires torch-optimi. These tests are skipped when torch-optimi isn't installed.
  386. https://github.com/jxnl/torch-optimi
  387. """
  388. return unittest.skipUnless(is_torch_optimi_available(), "test requires torch-optimi")(test_case)
  389. def require_lomo(test_case):
  390. """
  391. Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
  392. https://github.com/OpenLMLab/LOMO
  393. """
  394. return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)
  395. def require_grokadamw(test_case):
  396. """
  397. Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed.
  398. """
  399. return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)
  400. def require_schedulefree(test_case):
  401. """
  402. Decorator marking a test that requires schedulefree. These tests are skipped when schedulefree isn't installed.
  403. https://github.com/facebookresearch/schedule_free
  404. """
  405. return unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")(test_case)
  406. def require_cv2(test_case):
  407. """
  408. Decorator marking a test that requires OpenCV.
  409. These tests are skipped when OpenCV isn't installed.
  410. """
  411. return unittest.skipUnless(is_cv2_available(), "test requires OpenCV")(test_case)
  412. def require_levenshtein(test_case):
  413. """
  414. Decorator marking a test that requires Levenshtein.
  415. These tests are skipped when Levenshtein isn't installed.
  416. """
  417. return unittest.skipUnless(is_levenshtein_available(), "test requires Levenshtein")(test_case)
  418. def require_nltk(test_case):
  419. """
  420. Decorator marking a test that requires NLTK.
  421. These tests are skipped when NLTK isn't installed.
  422. """
  423. return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
  424. def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
  425. """
  426. Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
  427. """
  428. return unittest.skipUnless(
  429. is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}"
  430. )(test_case)
  431. def require_triton(min_version: str = TRITON_MIN_VERSION):
  432. """
  433. Decorator marking a test that requires triton. These tests are skipped when triton isn't installed.
  434. """
  435. def decorator(test_case):
  436. return unittest.skipUnless(is_triton_available(min_version), f"test requires triton version >= {min_version}")(
  437. test_case
  438. )
  439. return decorator
  440. def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION):
  441. """
  442. Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed.
  443. """
  444. return unittest.skipUnless(is_gguf_available(min_version), f"test requires gguf version >= {min_version}")(
  445. test_case
  446. )
  447. def require_fsdp(test_case, min_version: str = "1.12.0"):
  448. """
  449. Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed.
  450. """
  451. return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")(
  452. test_case
  453. )
  454. def require_g2p_en(test_case):
  455. """
  456. Decorator marking a test that requires g2p_en. These tests are skipped when SentencePiece isn't installed.
  457. """
  458. return unittest.skipUnless(is_g2p_en_available(), "test requires g2p_en")(test_case)
  459. def require_rjieba(test_case):
  460. """
  461. Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
  462. """
  463. return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case)
  464. def require_jinja(test_case):
  465. """
  466. Decorator marking a test that requires jinja. These tests are skipped when jinja isn't installed.
  467. """
  468. return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)
  469. def require_jmespath(test_case):
  470. """
  471. Decorator marking a test that requires jmespath. These tests are skipped when jmespath isn't installed.
  472. """
  473. return unittest.skipUnless(is_jmespath_available(), "test requires jmespath")(test_case)
  474. def require_onnx(test_case):
  475. return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)
  476. def require_timm(test_case):
  477. """
  478. Decorator marking a test that requires Timm.
  479. These tests are skipped when Timm isn't installed.
  480. """
  481. return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case)
  482. def require_natten(test_case):
  483. """
  484. Decorator marking a test that requires NATTEN.
  485. These tests are skipped when NATTEN isn't installed.
  486. """
  487. return unittest.skipUnless(is_natten_available(), "test requires natten")(test_case)
  488. def require_torch(test_case):
  489. """
  490. Decorator marking a test that requires PyTorch.
  491. These tests are skipped when PyTorch isn't installed.
  492. """
  493. return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
  494. def require_torch_greater_or_equal(version: str):
  495. """
  496. Decorator marking a test that requires PyTorch version >= `version`.
  497. These tests are skipped when PyTorch version is less than `version`.
  498. """
  499. def decorator(test_case):
  500. return unittest.skipUnless(is_torch_greater_or_equal(version), f"test requires PyTorch version >= {version}")(
  501. test_case
  502. )
  503. return decorator
  504. def require_huggingface_hub_greater_or_equal(version: str):
  505. """
  506. Decorator marking a test that requires huggingface_hub version >= `version`.
  507. These tests are skipped when huggingface_hub version is less than `version`.
  508. """
  509. def decorator(test_case):
  510. return unittest.skipUnless(
  511. is_huggingface_hub_greater_or_equal(version), f"test requires huggingface_hub version >= {version}"
  512. )(test_case)
  513. return decorator
  514. def require_flash_attn(test_case):
  515. """
  516. Decorator marking a test that requires Flash Attention.
  517. These tests are skipped when Flash Attention isn't installed.
  518. """
  519. flash_attn_available = is_flash_attn_2_available()
  520. kernels_available = is_kernels_available()
  521. try:
  522. from kernels import get_kernel
  523. get_kernel(FLASH_ATTN_KERNEL_FALLBACK["flash_attention_2"])
  524. except Exception as _:
  525. kernels_available = False
  526. return unittest.skipUnless(kernels_available | flash_attn_available, "test requires Flash Attention")(test_case)
  527. def require_kernels(test_case):
  528. """
  529. Decorator marking a test that requires the kernels library.
  530. These tests are skipped when the kernels library isn't installed.
  531. """
  532. return unittest.skipUnless(is_kernels_available(), "test requires the kernels library")(test_case)
  533. def require_flash_attn_3(test_case):
  534. """
  535. Decorator marking a test that requires Flash Attention 3.
  536. These tests are skipped when Flash Attention 3 isn't installed.
  537. """
  538. return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case)
  539. def require_flash_attn_4(test_case):
  540. """
  541. Decorator marking a test that requires Flash Attention 4.
  542. These tests are skipped when Flash Attention 4 isn't installed.
  543. """
  544. return unittest.skipUnless(is_flash_attn_4_available(), "test requires Flash Attention 4")(test_case)
  545. def require_all_flash_attn(test_case):
  546. flash_attn_available = is_flash_attn_2_available()
  547. kernels_available = is_kernels_available()
  548. try:
  549. from kernels import get_kernel
  550. get_kernel(FLASH_ATTN_KERNEL_FALLBACK["flash_attention_2"])
  551. except Exception as _:
  552. kernels_available = False
  553. return unittest.skipUnless(
  554. all(
  555. (
  556. flash_attn_available | kernels_available,
  557. is_flash_attn_3_available(),
  558. is_flash_attn_4_available(),
  559. )
  560. ),
  561. "test requires all mainline Flash Attention packages",
  562. )(test_case)
  563. def require_peft(test_case):
  564. """
  565. Decorator marking a test that requires PEFT.
  566. These tests are skipped when PEFT isn't installed.
  567. """
  568. return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case)
  569. def require_torchvision(test_case):
  570. """
  571. Decorator marking a test that requires Torchvision.
  572. These tests are skipped when Torchvision isn't installed.
  573. """
  574. return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)
  575. def require_torchcodec(test_case):
  576. """
  577. Decorator marking a test that requires Torchcodec.
  578. These tests are skipped when Torchcodec isn't installed.
  579. """
  580. return unittest.skipUnless(is_torchcodec_available(), "test requires Torchcodec")(test_case)
  581. def require_torchaudio(test_case):
  582. """
  583. Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed.
  584. """
  585. return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)
  586. def require_sentencepiece(test_case):
  587. """
  588. Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
  589. """
  590. return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case)
  591. def require_sacremoses(test_case):
  592. """
  593. Decorator marking a test that requires Sacremoses. These tests are skipped when Sacremoses isn't installed.
  594. """
  595. return unittest.skipUnless(is_sacremoses_available(), "test requires Sacremoses")(test_case)
  596. def require_seqio(test_case):
  597. """
  598. Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
  599. """
  600. return unittest.skipUnless(is_seqio_available(), "test requires Seqio")(test_case)
  601. def require_scipy(test_case):
  602. """
  603. Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed.
  604. """
  605. return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case)
  606. def require_tokenizers(test_case):
  607. """
  608. Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed.
  609. """
  610. return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case)
  611. def require_pandas(test_case):
  612. """
  613. Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
  614. """
  615. return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case)
  616. def require_pytesseract(test_case):
  617. """
  618. Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed.
  619. """
  620. return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case)
  621. def require_pytorch_quantization(test_case):
  622. """
  623. Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch
  624. Quantization Toolkit isn't installed.
  625. """
  626. return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")(
  627. test_case
  628. )
  629. def require_vision(test_case):
  630. """
  631. Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't
  632. installed.
  633. """
  634. return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case)
  635. def require_spacy(test_case):
  636. """
  637. Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed.
  638. """
  639. return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case)
  640. def require_torch_multi_gpu(test_case):
  641. """
  642. Decorator marking a test that requires a multi-GPU CUDA setup (in PyTorch). These tests are skipped on a machine without
  643. multiple CUDA GPUs.
  644. To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
  645. """
  646. if not is_torch_available():
  647. return unittest.skip(reason="test requires PyTorch")(test_case)
  648. import torch
  649. return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple CUDA GPUs")(test_case)
  650. def require_torch_multi_accelerator(test_case):
  651. """
  652. Decorator marking a test that requires a multi-accelerator (in PyTorch). These tests are skipped on a machine
  653. without multiple accelerators. To run *only* the multi_accelerator tests, assuming all test names contain
  654. multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator"
  655. """
  656. if not is_torch_available():
  657. return unittest.skip(reason="test requires PyTorch")(test_case)
  658. return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")(
  659. test_case
  660. )
  661. def require_torch_non_multi_gpu(test_case):
  662. """
  663. Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
  664. """
  665. if not is_torch_available():
  666. return unittest.skip(reason="test requires PyTorch")(test_case)
  667. import torch
  668. return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
  669. def require_torch_non_multi_accelerator(test_case):
  670. """
  671. Decorator marking a test that requires 0 or 1 accelerator setup (in PyTorch).
  672. """
  673. if not is_torch_available():
  674. return unittest.skip(reason="test requires PyTorch")(test_case)
  675. return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case)
  676. def require_torch_up_to_2_gpus(test_case):
  677. """
  678. Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
  679. """
  680. if not is_torch_available():
  681. return unittest.skip(reason="test requires PyTorch")(test_case)
  682. import torch
  683. return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)
  684. def require_torch_up_to_2_accelerators(test_case):
  685. """
  686. Decorator marking a test that requires 0 or 1 or 2 accelerator setup (in PyTorch).
  687. """
  688. if not is_torch_available():
  689. return unittest.skip(reason="test requires PyTorch")(test_case)
  690. return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")(
  691. test_case
  692. )
  693. def require_torch_xla(test_case):
  694. """
  695. Decorator marking a test that requires TorchXLA (in PyTorch).
  696. """
  697. return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)
  698. def require_torch_neuroncore(test_case):
  699. """
  700. Decorator marking a test that requires NeuronCore (in PyTorch).
  701. """
  702. return unittest.skipUnless(is_torch_neuroncore_available(check_device=False), "test requires PyTorch NeuronCore")(
  703. test_case
  704. )
  705. def require_torch_npu(test_case):
  706. """
  707. Decorator marking a test that requires NPU (in PyTorch).
  708. """
  709. return unittest.skipUnless(is_torch_npu_available(), "test requires PyTorch NPU")(test_case)
  710. def require_torch_multi_npu(test_case):
  711. """
  712. Decorator marking a test that requires a multi-NPU setup (in PyTorch). These tests are skipped on a machine without
  713. multiple NPUs.
  714. To run *only* the multi_npu tests, assuming all test names contain multi_npu: $ pytest -sv ./tests -k "multi_npu"
  715. """
  716. if not is_torch_npu_available():
  717. return unittest.skip(reason="test requires PyTorch NPU")(test_case)
  718. return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)
  719. def require_non_hpu(test_case):
  720. """
  721. Decorator marking a test that should be skipped for HPU.
  722. """
  723. return unittest.skipUnless(torch_device != "hpu", "test requires a non-HPU")(test_case)
  724. def require_torch_xpu(test_case):
  725. """
  726. Decorator marking a test that requires XPU (in PyTorch).
  727. These tests are skipped when XPU backend is not available.
  728. """
  729. return unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")(test_case)
  730. def require_non_xpu(test_case):
  731. """
  732. Decorator marking a test that should be skipped for XPU.
  733. """
  734. return unittest.skipUnless(torch_device != "xpu", "test requires a non-XPU")(test_case)
  735. def require_torch_multi_xpu(test_case):
  736. """
  737. Decorator marking a test that requires a multi-XPU setup (in PyTorch). These tests are skipped on a machine without
  738. multiple XPUs.
  739. To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
  740. """
  741. if not is_torch_xpu_available():
  742. return unittest.skip(reason="test requires PyTorch XPU")(test_case)
  743. return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
  744. def require_torch_multi_hpu(test_case):
  745. """
  746. Decorator marking a test that requires a multi-HPU setup (in PyTorch). These tests are skipped on a machine without
  747. multiple HPUs.
  748. To run *only* the multi_hpu tests, assuming all test names contain multi_hpu: $ pytest -sv ./tests -k "multi_hpu"
  749. """
  750. if not is_torch_hpu_available():
  751. return unittest.skip(reason="test requires PyTorch HPU")(test_case)
  752. return unittest.skipUnless(torch.hpu.device_count() > 1, "test requires multiple HPUs")(test_case)
# Resolve the module-level `torch_device` string read by the device-dependent decorators in this file
# (e.g. `require_torch_gpu`, `require_torch_accelerator`). It is None when PyTorch is not installed.
if is_torch_available():
    # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
    import torch

    # Optionally import the module named by TRANSFORMERS_TEST_BACKEND before the device is resolved.
    # NOTE(review): presumably importing it registers an extra device backend with torch — confirm with its docs.
    if "TRANSFORMERS_TEST_BACKEND" in os.environ:
        backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
        try:
            _ = importlib.import_module(backend)
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
                f" traceback):\n{e}"
            ) from e

    if "TRANSFORMERS_TEST_DEVICE" in os.environ:
        # An explicitly requested device wins over auto-detection; fail fast at import time if a known
        # backend was requested but is not usable in this environment.
        torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
        if torch_device == "cuda" and not torch.cuda.is_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but CUDA is unavailable. Please double-check your testing environment."
            )
        if torch_device == "xpu" and not is_torch_xpu_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but XPU is unavailable. Please double-check your testing environment."
            )
        if torch_device == "npu" and not is_torch_npu_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment."
            )
        if torch_device == "mlu" and not is_torch_mlu_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but MLU is unavailable. Please double-check your testing environment."
            )
        if torch_device == "hpu" and not is_torch_hpu_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but HPU is unavailable. Please double-check your testing environment."
            )
        try:
            # try creating device to see if provided device is valid
            _ = torch.device(torch_device)
        except RuntimeError as e:
            raise RuntimeError(
                f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}"
            ) from e
    # No explicit device: auto-detect in this fixed priority order (cuda, npu, mlu, hpu, xpu), falling back to cpu.
    # "mps" is never auto-selected here; it must be requested via TRANSFORMERS_TEST_DEVICE.
    elif torch.cuda.is_available():
        torch_device = "cuda"
    elif is_torch_npu_available():
        torch_device = "npu"
    elif is_torch_mlu_available():
        torch_device = "mlu"
    elif is_torch_hpu_available():
        torch_device = "hpu"
    elif is_torch_xpu_available():
        torch_device = "xpu"
    else:
        torch_device = "cpu"
else:
    # Without PyTorch there is no device; e.g. `require_torch_accelerator` then skips because torch_device is None.
    torch_device = None
  808. def require_torchao(test_case):
  809. """Decorator marking a test that requires torchao"""
  810. return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case)
  811. def require_torchao_version_greater_or_equal(torchao_version):
  812. def decorator(test_case):
  813. correct_torchao_version = is_torchao_available() and version.parse(
  814. version.parse(importlib.metadata.version("torchao")).base_version
  815. ) >= version.parse(torchao_version)
  816. return unittest.skipUnless(
  817. correct_torchao_version, f"Test requires torchao with the version greater than {torchao_version}."
  818. )(test_case)
  819. return decorator
  820. def require_torch_tensorrt_fx(test_case):
  821. """Decorator marking a test that requires Torch-TensorRT FX"""
  822. return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
  823. def require_torch_gpu(test_case):
  824. """Decorator marking a test that requires CUDA and PyTorch."""
  825. return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
  826. def require_torch_mps(test_case):
  827. """Decorator marking a test that requires CUDA and PyTorch."""
  828. return unittest.skipUnless(torch_device == "mps", "test requires MPS")(test_case)
  829. def require_large_cpu_ram(test_case, memory: float = 80):
  830. """Decorator marking a test that requires a CPU RAM with more than `memory` GiB of memory."""
  831. if not is_psutil_available():
  832. return test_case
  833. import psutil
  834. return unittest.skipUnless(
  835. psutil.virtual_memory().total / 1024**3 > memory,
  836. f"test requires a machine with more than {memory} GiB of CPU RAM memory",
  837. )(test_case)
  838. def require_torch_large_gpu(test_case, memory: float = 20):
  839. """Decorator marking a test that requires a CUDA GPU with more than `memory` GiB of memory."""
  840. if torch_device != "cuda":
  841. return unittest.skip(reason=f"test requires a CUDA GPU with more than {memory} GiB of memory")(test_case)
  842. return unittest.skipUnless(
  843. torch.cuda.get_device_properties(0).total_memory / 1024**3 > memory,
  844. f"test requires a GPU with more than {memory} GiB of memory",
  845. )(test_case)
  846. def require_torch_large_accelerator(test_case=None, *, memory: float = 20):
  847. """Decorator marking a test that requires an accelerator with more than `memory` GiB of memory."""
  848. def memory_decorator(tc):
  849. if torch_device not in ("cuda", "xpu"):
  850. return unittest.skip(f"test requires a GPU or XPU with more than {memory} GiB of memory")(tc)
  851. torch_accel = getattr(torch, torch_device)
  852. return unittest.skipUnless(
  853. torch_accel.get_device_properties(0).total_memory / 1024**3 > memory,
  854. f"test requires a GPU or XPU with more than {memory} GiB of memory",
  855. )(tc)
  856. return memory_decorator if test_case is None else memory_decorator(test_case)
  857. def require_torch_accelerator(test_case):
  858. """Decorator marking a test that requires an accessible accelerator and PyTorch."""
  859. return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")(
  860. test_case
  861. )
  862. def require_torch_fp16(test_case):
  863. """Decorator marking a test that requires a device that supports fp16"""
  864. return unittest.skipUnless(
  865. is_torch_fp16_available_on_device(torch_device), "test requires device with fp16 support"
  866. )(test_case)
  867. def require_fp8(test_case):
  868. """Decorator marking a test that requires supports for fp8"""
  869. return unittest.skipUnless(is_accelerate_available() and is_fp8_available(), "test requires fp8 support")(
  870. test_case
  871. )
  872. def require_cuda_capability_at_least(major, minor):
  873. """Decorator to skip tests when CUDA capability is below the given version."""
  874. import torch
  875. if not torch.cuda.is_available():
  876. return unittest.skip("CUDA not available")
  877. capability = torch.cuda.get_device_capability()
  878. return unittest.skipIf(capability < (major, minor), f"Requires CUDA capability >= {major}.{minor}")
  879. def require_torch_bf16(test_case):
  880. """Decorator marking a test that requires a device that supports bf16"""
  881. return unittest.skipUnless(
  882. is_torch_bf16_available_on_device(torch_device), "test requires device with bf16 support"
  883. )(test_case)
  884. def require_deterministic_for_xpu(test_case):
  885. @wraps(test_case)
  886. def wrapper(*args, **kwargs):
  887. if is_torch_xpu_available():
  888. original_state = torch.are_deterministic_algorithms_enabled()
  889. try:
  890. torch.use_deterministic_algorithms(True)
  891. return test_case(*args, **kwargs)
  892. finally:
  893. torch.use_deterministic_algorithms(original_state)
  894. else:
  895. return test_case(*args, **kwargs)
  896. return wrapper
  897. def require_torch_tf32(test_case):
  898. """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
  899. return unittest.skipUnless(
  900. is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
  901. )(test_case)
  902. def require_detectron2(test_case):
  903. """Decorator marking a test that requires detectron2."""
  904. return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case)
  905. def require_faiss(test_case):
  906. """Decorator marking a test that requires faiss."""
  907. return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)
  908. def require_optuna(test_case):
  909. """
  910. Decorator marking a test that requires optuna.
  911. These tests are skipped when optuna isn't installed.
  912. """
  913. return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case)
  914. def require_ray(test_case):
  915. """
  916. Decorator marking a test that requires Ray/tune.
  917. These tests are skipped when Ray/tune isn't installed.
  918. """
  919. return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case)
  920. def require_swanlab(test_case):
  921. """
  922. Decorator marking a test that requires swanlab.
  923. These tests are skipped when swanlab isn't installed.
  924. """
  925. return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case)
  926. def require_trackio(test_case):
  927. """
  928. Decorator marking a test that requires trackio.
  929. These tests are skipped when trackio isn't installed.
  930. """
  931. return unittest.skipUnless(is_trackio_available(), "test requires trackio")(test_case)
  932. def require_wandb(test_case):
  933. """
  934. Decorator marking a test that requires wandb.
  935. These tests are skipped when wandb isn't installed.
  936. """
  937. return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
  938. def require_clearml(test_case):
  939. """
  940. Decorator marking a test requires clearml.
  941. These tests are skipped when clearml isn't installed.
  942. """
  943. return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case)
  944. def require_deepspeed(test_case):
  945. """
  946. Decorator marking a test that requires deepspeed
  947. """
  948. return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)
  949. def require_apex(test_case):
  950. """
  951. Decorator marking a test that requires apex
  952. """
  953. return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)
  954. def require_aqlm(test_case):
  955. """
  956. Decorator marking a test that requires aqlm
  957. """
  958. return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
  959. def require_vptq(test_case):
  960. """
  961. Decorator marking a test that requires vptq
  962. """
  963. return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case)
  964. def require_spqr(test_case):
  965. """
  966. Decorator marking a test that requires spqr
  967. """
  968. return unittest.skipUnless(is_spqr_available(), "test requires spqr")(test_case)
  969. def require_av(test_case):
  970. """
  971. Decorator marking a test that requires av
  972. """
  973. return unittest.skipUnless(is_av_available(), "test requires av")(test_case)
  974. def require_decord(test_case):
  975. """
  976. Decorator marking a test that requires decord
  977. """
  978. return unittest.skipUnless(is_decord_available(), "test requires decord")(test_case)
  979. def require_bitsandbytes(test_case):
  980. """
  981. Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed.
  982. """
  983. return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
  984. def require_optimum(test_case):
  985. """
  986. Decorator for optimum dependency
  987. """
  988. return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case)
  989. def require_tensorboard(test_case):
  990. """
  991. Decorator for `tensorboard` dependency
  992. """
  993. return unittest.skipUnless(is_tensorboard_available(), "test requires tensorboard")
  994. def require_gptqmodel(test_case):
  995. """
  996. Decorator for gptqmodel dependency
  997. """
  998. return unittest.skipUnless(is_gptqmodel_available(), "test requires gptqmodel")(test_case)
  999. def require_hqq(test_case):
  1000. """
  1001. Decorator for hqq dependency
  1002. """
  1003. return unittest.skipUnless(is_hqq_available(), "test requires hqq")(test_case)
  1004. def require_auto_round(test_case):
  1005. """
  1006. Decorator for auto_round dependency
  1007. """
  1008. return unittest.skipUnless(is_auto_round_available(), "test requires autoround")(test_case)
  1009. def require_optimum_quanto(test_case):
  1010. """
  1011. Decorator for quanto dependency
  1012. """
  1013. return unittest.skipUnless(is_optimum_quanto_available(), "test requires optimum-quanto")(test_case)
  1014. def require_compressed_tensors(test_case):
  1015. """
  1016. Decorator for compressed_tensors dependency
  1017. """
  1018. return unittest.skipUnless(is_compressed_tensors_available(), "test requires compressed_tensors")(test_case)
  1019. def require_fbgemm_gpu(test_case):
  1020. """
  1021. Decorator for fbgemm_gpu dependency
  1022. """
  1023. return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case)
  1024. def require_quark(test_case):
  1025. """
  1026. Decorator for quark dependency
  1027. """
  1028. return unittest.skipUnless(is_quark_available(), "test requires quark")(test_case)
  1029. def require_flute_hadamard(test_case):
  1030. """
  1031. Decorator marking a test that requires higgs and hadamard
  1032. """
  1033. return unittest.skipUnless(
  1034. is_flute_available() and is_hadamard_available(), "test requires flute and fast_hadamard_transform"
  1035. )(test_case)
  1036. def require_fouroversix(test_case):
  1037. """
  1038. Decorator marking a test that requires fouroversix
  1039. """
  1040. return unittest.skipUnless(is_fouroversix_available(), "test requires fouroversix")(test_case)
  1041. def require_fp_quant(test_case):
  1042. """
  1043. Decorator marking a test that requires fp_quant and qutlass
  1044. """
  1045. return unittest.skipUnless(is_fp_quant_available(), "test requires fp_quant")(test_case)
  1046. def require_qutlass(test_case):
  1047. """
  1048. Decorator marking a test that requires qutlass
  1049. """
  1050. return unittest.skipUnless(is_qutlass_available(), "test requires qutlass")(test_case)
  1051. def require_phonemizer(test_case):
  1052. """
  1053. Decorator marking a test that requires phonemizer
  1054. """
  1055. return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case)
  1056. def require_pyctcdecode(test_case):
  1057. """
  1058. Decorator marking a test that requires pyctcdecode
  1059. """
  1060. return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case)
  1061. def require_numba(test_case):
  1062. """
  1063. Decorator marking a test that requires numba
  1064. """
  1065. return unittest.skipUnless(is_numba_available(), "test requires numba")(test_case)
  1066. def require_librosa(test_case):
  1067. """
  1068. Decorator marking a test that requires librosa
  1069. """
  1070. return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
  1071. def require_multipart(test_case):
  1072. """
  1073. Decorator marking a test that requires python-multipart
  1074. """
  1075. return unittest.skipUnless(is_multipart_available(), "test requires python-multipart")(test_case)
  1076. def require_liger_kernel(test_case):
  1077. """
  1078. Decorator marking a test that requires liger_kernel
  1079. """
  1080. return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case)
  1081. def require_essentia(test_case):
  1082. """
  1083. Decorator marking a test that requires essentia
  1084. """
  1085. return unittest.skipUnless(is_essentia_available(), "test requires essentia")(test_case)
  1086. def require_pretty_midi(test_case):
  1087. """
  1088. Decorator marking a test that requires pretty_midi
  1089. """
  1090. return unittest.skipUnless(is_pretty_midi_available(), "test requires pretty_midi")(test_case)
  1091. def cmd_exists(cmd):
  1092. return shutil.which(cmd) is not None
  1093. def require_usr_bin_time(test_case):
  1094. """
  1095. Decorator marking a test that requires `/usr/bin/time`
  1096. """
  1097. return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case)
  1098. def require_sudachi(test_case):
  1099. """
  1100. Decorator marking a test that requires sudachi
  1101. """
  1102. return unittest.skipUnless(is_sudachi_available(), "test requires sudachi")(test_case)
  1103. def require_sudachi_projection(test_case):
  1104. """
  1105. Decorator marking a test that requires sudachi_projection
  1106. """
  1107. return unittest.skipUnless(is_sudachi_projection_available(), "test requires sudachi which supports projection")(
  1108. test_case
  1109. )
  1110. def require_jumanpp(test_case):
  1111. """
  1112. Decorator marking a test that requires jumanpp
  1113. """
  1114. return unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")(test_case)
  1115. def require_cython(test_case):
  1116. """
  1117. Decorator marking a test that requires jumanpp
  1118. """
  1119. return unittest.skipUnless(is_cython_available(), "test requires cython")(test_case)
  1120. def require_tiktoken(test_case):
  1121. """
  1122. Decorator marking a test that requires TikToken. These tests are skipped when TikToken isn't installed.
  1123. """
  1124. return unittest.skipUnless(is_tiktoken_available(), "test requires TikToken")(test_case)
  1125. def require_speech(test_case):
  1126. """
  1127. Decorator marking a test that requires speech. These tests are skipped when speech isn't available.
  1128. """
  1129. return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case)
  1130. def require_openai(test_case):
  1131. """
  1132. Decorator marking a test that requires openai
  1133. """
  1134. return unittest.skipUnless(is_openai_available(), "test requires openai")(test_case)
  1135. def require_serve(test_case):
  1136. """
  1137. Decorator marking a test that requires the serving dependencies (fastapi, uvicorn, pydantic, openai).
  1138. """
  1139. return unittest.skipUnless(is_serve_available(), "test requires serving dependencies")(test_case)
  1140. def require_mistral_common(test_case):
  1141. """
  1142. Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available.
  1143. """
  1144. return unittest.skipUnless(is_mistral_common_available(), "test requires mistral-common")(test_case)
  1145. def get_gpu_count():
  1146. """
  1147. Return the number of available gpus
  1148. """
  1149. if is_torch_available():
  1150. import torch
  1151. return torch.cuda.device_count()
  1152. else:
  1153. return 0
  1154. def get_tests_dir(append_path=None):
  1155. """
  1156. Args:
  1157. append_path: optional path to append to the tests dir path
  1158. Return:
  1159. The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
  1160. joined after the `tests` dir the former is provided.
  1161. """
  1162. # this function caller's __file__
  1163. caller__file__ = inspect.stack()[1][1]
  1164. tests_dir = os.path.abspath(os.path.dirname(caller__file__))
  1165. while not tests_dir.endswith("tests"):
  1166. tests_dir = os.path.dirname(tests_dir)
  1167. if append_path:
  1168. return os.path.join(tests_dir, append_path)
  1169. else:
  1170. return tests_dir
  1171. def get_steps_per_epoch(trainer: Trainer) -> int:
  1172. training_args = trainer.args
  1173. train_dataloader = trainer.get_train_dataloader()
  1174. initial_training_values = trainer.set_initial_training_values(args=training_args, dataloader=train_dataloader)
  1175. steps_per_epoch = initial_training_values[5]
  1176. return steps_per_epoch
  1177. def evaluate_side_effect_factory(
  1178. side_effect_values: list[dict[str, float]],
  1179. ) -> Generator[dict[str, float], None, None]:
  1180. """
  1181. Function that returns side effects for the _evaluate method.
  1182. Used when we're unsure of exactly how many times _evaluate will be called.
  1183. """
  1184. yield from side_effect_values
  1185. while True:
  1186. yield side_effect_values[-1]
#
# Helper functions for dealing with testing text outputs
# The original code came from:
# https://github.com/fastai/fastai/blob/master/tests/utils/text.py
# When a function contains print() calls whose output gets overwritten (e.g. progress bars),
# special care is needed. Under `pytest -s`, captured output (capsys or
# contextlib.redirect_stdout) contains every temporarily printed string, each followed by
# \r. This helper ensures that the buffer contains the same output
# with and without -s in pytest, by turning:
# foo bar\r tar mar\r final message
# into:
# final message
# It can handle a single string or a multiline buffer.
  1200. def apply_print_resets(buf):
  1201. return re.sub(r"^.*\r", "", buf, 0, re.MULTILINE)
  1202. def assert_screenout(out, what):
  1203. out_pr = apply_print_resets(out).lower()
  1204. match_str = out_pr.find(what.lower())
  1205. assert match_str != -1, f"expecting to find {what} in output: f{out_pr}"
  1206. def set_config_for_less_flaky_test(config):
  1207. target_attrs = [
  1208. "rms_norm_eps",
  1209. "layer_norm_eps",
  1210. "norm_eps",
  1211. "norm_epsilon",
  1212. "layer_norm_epsilon",
  1213. "batch_norm_eps",
  1214. ]
  1215. for target_attr in target_attrs:
  1216. setattr(config, target_attr, 1.0)
  1217. # norm layers (layer/group norm, etc.) could cause flaky tests when the tensors have very small variance.
  1218. # (We don't need the original epsilon values to check eager/sdpa matches)
  1219. attrs = ["text_config", "vision_config", "audio_config", "text_encoder", "audio_encoder", "decoder"]
  1220. for attr in attrs:
  1221. if hasattr(config, attr) and getattr(config, attr) is not None:
  1222. for target_attr in target_attrs:
  1223. setattr(getattr(config, attr), target_attr, 1.0)
  1224. def set_model_for_less_flaky_test(model):
  1225. # Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.)
  1226. target_names = (
  1227. "LayerNorm",
  1228. "GroupNorm",
  1229. "BatchNorm",
  1230. "RMSNorm",
  1231. "BatchNorm2d",
  1232. "BatchNorm1d",
  1233. "BitGroupNormActivation",
  1234. "WeightStandardizedConv2d",
  1235. )
  1236. target_attrs = ["eps", "epsilon", "variance_epsilon"]
  1237. if is_torch_available() and isinstance(model, torch.nn.Module):
  1238. for module in model.modules():
  1239. if type(module).__name__.endswith(target_names):
  1240. for attr in target_attrs:
  1241. if hasattr(module, attr):
  1242. setattr(module, attr, 1.0)
  1243. class CaptureStd:
  1244. """
  1245. Context manager to capture:
  1246. - stdout: replay it, clean it up and make it available via `obj.out`
  1247. - stderr: replay it and make it available via `obj.err`
  1248. Args:
  1249. out (`bool`, *optional*, defaults to `True`): Whether to capture stdout or not.
  1250. err (`bool`, *optional*, defaults to `True`): Whether to capture stderr or not.
  1251. replay (`bool`, *optional*, defaults to `True`): Whether to replay or not.
  1252. By default each captured stream gets replayed back on context's exit, so that one can see what the test was
  1253. doing. If this is a not wanted behavior and the captured data shouldn't be replayed, pass `replay=False` to
  1254. disable this feature.
  1255. Examples:
  1256. ```python
  1257. # to capture stdout only with auto-replay
  1258. with CaptureStdout() as cs:
  1259. print("Secret message")
  1260. assert "message" in cs.out
  1261. # to capture stderr only with auto-replay
  1262. import sys
  1263. with CaptureStderr() as cs:
  1264. print("Warning: ", file=sys.stderr)
  1265. assert "Warning" in cs.err
  1266. # to capture both streams with auto-replay
  1267. with CaptureStd() as cs:
  1268. print("Secret message")
  1269. print("Warning: ", file=sys.stderr)
  1270. assert "message" in cs.out
  1271. assert "Warning" in cs.err
  1272. # to capture just one of the streams, and not the other, with auto-replay
  1273. with CaptureStd(err=False) as cs:
  1274. print("Secret message")
  1275. assert "message" in cs.out
  1276. # but best use the stream-specific subclasses
  1277. # to capture without auto-replay
  1278. with CaptureStd(replay=False) as cs:
  1279. print("Secret message")
  1280. assert "message" in cs.out
  1281. ```"""
  1282. def __init__(self, out=True, err=True, replay=True):
  1283. self.replay = replay
  1284. if out:
  1285. self.out_buf = StringIO()
  1286. self.out = "error: CaptureStd context is unfinished yet, called too early"
  1287. else:
  1288. self.out_buf = None
  1289. self.out = "not capturing stdout"
  1290. if err:
  1291. self.err_buf = StringIO()
  1292. self.err = "error: CaptureStd context is unfinished yet, called too early"
  1293. else:
  1294. self.err_buf = None
  1295. self.err = "not capturing stderr"
  1296. def __enter__(self):
  1297. if self.out_buf:
  1298. self.out_old = sys.stdout
  1299. sys.stdout = self.out_buf
  1300. if self.err_buf:
  1301. self.err_old = sys.stderr
  1302. sys.stderr = self.err_buf
  1303. return self
  1304. def __exit__(self, *exc):
  1305. if self.out_buf:
  1306. sys.stdout = self.out_old
  1307. captured = self.out_buf.getvalue()
  1308. if self.replay:
  1309. sys.stdout.write(captured)
  1310. self.out = apply_print_resets(captured)
  1311. if self.err_buf:
  1312. sys.stderr = self.err_old
  1313. captured = self.err_buf.getvalue()
  1314. if self.replay:
  1315. sys.stderr.write(captured)
  1316. self.err = captured
  1317. def __repr__(self):
  1318. msg = ""
  1319. if self.out_buf:
  1320. msg += f"stdout: {self.out}\n"
  1321. if self.err_buf:
  1322. msg += f"stderr: {self.err}\n"
  1323. return msg
# In tests it's best to capture only the stream you need; otherwise it's easy to
# miss things. Unless you need to capture both streams, use the subclasses below
# (less typing), or configure `CaptureStd` to disable the stream you don't need
# to test.
  1328. class CaptureStdout(CaptureStd):
  1329. """Same as CaptureStd but captures only stdout"""
  1330. def __init__(self, replay=True):
  1331. super().__init__(err=False, replay=replay)
  1332. class CaptureStderr(CaptureStd):
  1333. """Same as CaptureStd but captures only stderr"""
  1334. def __init__(self, replay=True):
  1335. super().__init__(out=False, replay=replay)
  1336. class CaptureLogger:
  1337. """
  1338. Context manager to capture `logging` streams
  1339. Args:
  1340. logger: 'logging` logger object
  1341. Returns:
  1342. The captured output is available via `self.out`
  1343. Example:
  1344. ```python
  1345. >>> from transformers import logging
  1346. >>> from transformers.testing_utils import CaptureLogger
  1347. >>> msg = "Testing 1, 2, 3"
  1348. >>> logging.set_verbosity_info()
  1349. >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
  1350. >>> with CaptureLogger(logger) as cl:
  1351. ... logger.info(msg)
  1352. >>> assert cl.out, msg + "\n"
  1353. ```
  1354. """
  1355. def __init__(self, logger):
  1356. self.logger = logger
  1357. self.io = StringIO()
  1358. self.sh = logging.StreamHandler(self.io)
  1359. self.out = ""
  1360. def __enter__(self):
  1361. self.logger.addHandler(self.sh)
  1362. return self
  1363. def __exit__(self, *exc):
  1364. self.logger.removeHandler(self.sh)
  1365. self.out = self.io.getvalue()
  1366. def __repr__(self):
  1367. return f"captured: {self.out}\n"
  1368. @contextlib.contextmanager
  1369. def LoggingLevel(level):
  1370. """
  1371. This is a context manager to temporarily change transformers modules logging level to the desired value and have it
  1372. restored to the original setting at the end of the scope.
  1373. Example:
  1374. ```python
  1375. with LoggingLevel(logging.INFO):
  1376. AutoModel.from_pretrained("openai-community/gpt2") # calls logger.info() several times
  1377. ```
  1378. """
  1379. orig_level = transformers_logging.get_verbosity()
  1380. try:
  1381. transformers_logging.set_verbosity(level)
  1382. yield
  1383. finally:
  1384. transformers_logging.set_verbosity(orig_level)
  1385. class TemporaryHubRepo:
  1386. """Create a temporary Hub repository and return its `RepoUrl` object. This is similar to
  1387. `tempfile.TemporaryDirectory` and can be used as a context manager. For example:
  1388. with TemporaryHubRepo(token=self._token) as temp_repo:
  1389. ...
  1390. Upon exiting the context, the repository and everything contained in it are removed.
  1391. Example:
  1392. ```python
  1393. with TemporaryHubRepo(token=self._token) as temp_repo:
  1394. model.push_to_hub(tmp_repo.repo_id, token=self._token)
  1395. ```
  1396. """
  1397. def __init__(self, namespace: str | None = None, token: str | None = None) -> None:
  1398. self.token = token
  1399. with tempfile.TemporaryDirectory() as tmp_dir:
  1400. repo_id = Path(tmp_dir).name
  1401. if namespace is not None:
  1402. repo_id = f"{namespace}/{repo_id}"
  1403. self.repo_url = create_repo(repo_id, token=self.token)
  1404. def __enter__(self):
  1405. return self.repo_url
  1406. def __exit__(self, exc, value, tb):
  1407. delete_repo(repo_id=self.repo_url.repo_id, token=self.token, missing_ok=True)
  1408. @contextlib.contextmanager
  1409. # adapted from https://stackoverflow.com/a/64789046/9201239
  1410. def ExtendSysPath(path: str | os.PathLike) -> Iterator[None]:
  1411. """
  1412. Temporary add given path to `sys.path`.
  1413. Usage :
  1414. ```python
  1415. with ExtendSysPath("/path/to/dir"):
  1416. mymodule = importlib.import_module("mymodule")
  1417. ```
  1418. """
  1419. path = os.fspath(path)
  1420. try:
  1421. sys.path.insert(0, path)
  1422. yield
  1423. finally:
  1424. sys.path.remove(path)
  1425. class TestCasePlus(unittest.TestCase):
  1426. """
  1427. This class extends *unittest.TestCase* with additional features.
  1428. Feature 1: A set of fully resolved important file and dir path accessors.
  1429. In tests often we need to know where things are relative to the current test file, and it's not trivial since the
  1430. test could be invoked from more than one directory or could reside in sub-directories with different depths. This
  1431. class solves this problem by sorting out all the basic paths and provides easy accessors to them:
  1432. - `pathlib` objects (all fully resolved):
  1433. - `test_file_path` - the current test file path (=`__file__`)
  1434. - `test_file_dir` - the directory containing the current test file
  1435. - `tests_dir` - the directory of the `tests` test suite
  1436. - `examples_dir` - the directory of the `examples` test suite
  1437. - `repo_root_dir` - the directory of the repository
  1438. - `src_dir` - the directory of `src` (i.e. where the `transformers` sub-dir resides)
  1439. - stringified paths---same as above but these return paths as strings, rather than `pathlib` objects:
  1440. - `test_file_path_str`
  1441. - `test_file_dir_str`
  1442. - `tests_dir_str`
  1443. - `examples_dir_str`
  1444. - `repo_root_dir_str`
  1445. - `src_dir_str`
  1446. Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test.
  1447. 1. Create a unique temporary dir:
  1448. ```python
  1449. def test_whatever(self):
  1450. tmp_dir = self.get_auto_remove_tmp_dir()
  1451. ```
  1452. `tmp_dir` will contain the path to the created temporary dir. It will be automatically removed at the end of the
  1453. test.
  1454. 2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't
  1455. empty it after the test.
  1456. ```python
  1457. def test_whatever(self):
  1458. tmp_dir = self.get_auto_remove_tmp_dir("./xxx")
  1459. ```
  1460. This is useful for debug when you want to monitor a specific directory and want to make sure the previous tests
  1461. didn't leave any data in there.
  1462. 3. You can override the first two options by directly overriding the `before` and `after` args, leading to the
  1463. following behavior:
  1464. `before=True`: the temporary dir will always be cleared at the beginning of the test.
  1465. `before=False`: if the temporary dir already existed, any existing files will remain there.
  1466. `after=True`: the temporary dir will always be deleted at the end of the test.
  1467. `after=False`: the temporary dir will always be left intact at the end of the test.
  1468. Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the project repository checkout are
  1469. allowed if an explicit `tmp_dir` is used, so that by mistake no `/tmp` or similar important part of the filesystem
  1470. will get nuked. i.e. please always pass paths that start with `./`
  1471. Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested
  1472. otherwise.
  1473. Feature 3: Get a copy of the `os.environ` object that sets up `PYTHONPATH` specific to the current test suite. This
  1474. is useful for invoking external programs from the test suite - e.g. distributed training.
  1475. ```python
  1476. def test_whatever(self):
  1477. env = self.get_env()
  1478. ```"""
  1479. def setUp(self):
  1480. # get_auto_remove_tmp_dir feature:
  1481. self.teardown_tmp_dirs = []
  1482. # figure out the resolved paths for repo_root, tests, examples, etc.
  1483. self._test_file_path = inspect.getfile(self.__class__)
  1484. path = Path(self._test_file_path).resolve()
  1485. self._test_file_dir = path.parents[0]
  1486. for up in [1, 2, 3]:
  1487. tmp_dir = path.parents[up]
  1488. if (tmp_dir / "src").is_dir() and (tmp_dir / "tests").is_dir():
  1489. break
  1490. if tmp_dir:
  1491. self._repo_root_dir = tmp_dir
  1492. else:
  1493. raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}")
  1494. self._tests_dir = self._repo_root_dir / "tests"
  1495. self._examples_dir = self._repo_root_dir / "examples"
  1496. self._src_dir = self._repo_root_dir / "src"
  1497. @property
  1498. def test_file_path(self):
  1499. return self._test_file_path
  1500. @property
  1501. def test_file_path_str(self):
  1502. return str(self._test_file_path)
  1503. @property
  1504. def test_file_dir(self):
  1505. return self._test_file_dir
  1506. @property
  1507. def test_file_dir_str(self):
  1508. return str(self._test_file_dir)
  1509. @property
  1510. def tests_dir(self):
  1511. return self._tests_dir
  1512. @property
  1513. def tests_dir_str(self):
  1514. return str(self._tests_dir)
  1515. @property
  1516. def examples_dir(self):
  1517. return self._examples_dir
  1518. @property
  1519. def examples_dir_str(self):
  1520. return str(self._examples_dir)
  1521. @property
  1522. def repo_root_dir(self):
  1523. return self._repo_root_dir
  1524. @property
  1525. def repo_root_dir_str(self):
  1526. return str(self._repo_root_dir)
  1527. @property
  1528. def src_dir(self):
  1529. return self._src_dir
  1530. @property
  1531. def src_dir_str(self):
  1532. return str(self._src_dir)
  1533. def get_env(self):
  1534. """
  1535. Return a copy of the `os.environ` object that sets up `PYTHONPATH` correctly, depending on the test suite it's
  1536. invoked from. This is useful for invoking external programs from the test suite - e.g. distributed training.
  1537. It always inserts `./src` first, then `./tests` or `./examples` depending on the test suite type and finally
  1538. the preset `PYTHONPATH` if any (all full resolved paths).
  1539. """
  1540. env = os.environ.copy()
  1541. paths = [self.repo_root_dir_str, self.src_dir_str]
  1542. if "/examples" in self.test_file_dir_str:
  1543. paths.append(self.examples_dir_str)
  1544. paths.append(env.get("PYTHONPATH", ""))
  1545. env["PYTHONPATH"] = ":".join(paths)
  1546. return env
  1547. def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None, return_pathlib_obj=False):
  1548. """
  1549. Args:
  1550. tmp_dir (`string`, *optional*):
  1551. if `None`:
  1552. - a unique temporary path will be created
  1553. - sets `before=True` if `before` is `None`
  1554. - sets `after=True` if `after` is `None`
  1555. else:
  1556. - `tmp_dir` will be created
  1557. - sets `before=True` if `before` is `None`
  1558. - sets `after=False` if `after` is `None`
  1559. before (`bool`, *optional*):
  1560. If `True` and the `tmp_dir` already exists, make sure to empty it right away if `False` and the
  1561. `tmp_dir` already exists, any existing files will remain there.
  1562. after (`bool`, *optional*):
  1563. If `True`, delete the `tmp_dir` at the end of the test if `False`, leave the `tmp_dir` and its contents
  1564. intact at the end of the test.
  1565. return_pathlib_obj (`bool`, *optional*):
  1566. If `True` will return a pathlib.Path object
  1567. Returns:
  1568. tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
  1569. """
  1570. if tmp_dir is not None:
  1571. # defining the most likely desired behavior for when a custom path is provided.
  1572. # this most likely indicates the debug mode where we want an easily locatable dir that:
  1573. # 1. gets cleared out before the test (if it already exists)
  1574. # 2. is left intact after the test
  1575. if before is None:
  1576. before = True
  1577. if after is None:
  1578. after = False
  1579. # using provided path
  1580. path = Path(tmp_dir).resolve()
  1581. # to avoid nuking parts of the filesystem, only relative paths are allowed
  1582. if not tmp_dir.startswith("./"):
  1583. raise ValueError(
  1584. f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
  1585. )
  1586. # ensure the dir is empty to start with
  1587. if before is True and path.exists():
  1588. shutil.rmtree(tmp_dir, ignore_errors=True)
  1589. path.mkdir(parents=True, exist_ok=True)
  1590. else:
  1591. # defining the most likely desired behavior for when a unique tmp path is auto generated
  1592. # (not a debug mode), here we require a unique tmp dir that:
  1593. # 1. is empty before the test (it will be empty in this situation anyway)
  1594. # 2. gets fully removed after the test
  1595. if before is None:
  1596. before = True
  1597. if after is None:
  1598. after = True
  1599. # using unique tmp dir (always empty, regardless of `before`)
  1600. tmp_dir = tempfile.mkdtemp()
  1601. if after is True:
  1602. # register for deletion
  1603. self.teardown_tmp_dirs.append(tmp_dir)
  1604. return Path(tmp_dir).resolve() if return_pathlib_obj else tmp_dir
  1605. def python_one_liner_max_rss(self, one_liner_str):
  1606. """
  1607. Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the
  1608. program.
  1609. Args:
  1610. one_liner_str (`string`):
  1611. a python one liner code that gets passed to `python -c`
  1612. Returns:
  1613. max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run.
  1614. Requirements:
  1615. this helper needs `/usr/bin/time` to be installed (`apt install time`)
  1616. Example:
  1617. ```
  1618. one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("google-t5/t5-large")'
  1619. max_rss = self.python_one_liner_max_rss(one_liner_str)
  1620. ```
  1621. """
  1622. if not cmd_exists("/usr/bin/time"):
  1623. raise ValueError("/usr/bin/time is required, install with `apt install time`")
  1624. cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'")
  1625. with CaptureStd() as cs:
  1626. execute_subprocess_async(cmd, env=self.get_env())
  1627. # returned data is in KB so convert to bytes
  1628. max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024
  1629. return max_rss
  1630. def tearDown(self):
  1631. # get_auto_remove_tmp_dir feature: remove registered temp dirs
  1632. for path in self.teardown_tmp_dirs:
  1633. shutil.rmtree(path, ignore_errors=True)
  1634. self.teardown_tmp_dirs = []
  1635. if is_accelerate_available():
  1636. AcceleratorState._reset_state()
  1637. PartialState._reset_state()
  1638. # delete all the env variables having `ACCELERATE` in them
  1639. for k in list(os.environ.keys()):
  1640. if "ACCELERATE" in k:
  1641. del os.environ[k]
  1642. def mockenv(**kwargs):
  1643. """
  1644. this is a convenience wrapper, that allows this ::
  1645. @mockenv(RUN_SLOW=True, USE_TF=False) def test_something():
  1646. run_slow = os.getenv("RUN_SLOW", False) use_tf = os.getenv("USE_TF", False)
  1647. """
  1648. return mock.patch.dict(os.environ, kwargs)
  1649. # from https://stackoverflow.com/a/34333710/9201239
  1650. @contextlib.contextmanager
  1651. def mockenv_context(*remove, **update):
  1652. """
  1653. Temporarily updates the `os.environ` dictionary in-place. Similar to mockenv
  1654. The `os.environ` dictionary is updated in-place so that the modification is sure to work in all situations.
  1655. Args:
  1656. remove: Environment variables to remove.
  1657. update: Dictionary of environment variables and values to add/update.
  1658. """
  1659. env = os.environ
  1660. update = update or {}
  1661. remove = remove or []
  1662. # List of environment variables being updated or removed.
  1663. stomped = (set(update.keys()) | set(remove)) & set(env.keys())
  1664. # Environment variables and values to restore on exit.
  1665. update_after = {k: env[k] for k in stomped}
  1666. # Environment variables and values to remove on exit.
  1667. remove_after = frozenset(k for k in update if k not in env)
  1668. try:
  1669. env.update(update)
  1670. [env.pop(k, None) for k in remove]
  1671. yield
  1672. finally:
  1673. env.update(update_after)
  1674. [env.pop(k) for k in remove_after]
  1675. # --- pytest conf functions --- #
  1676. # to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
  1677. pytest_opt_registered = {}
  1678. def pytest_addoption_shared(parser):
  1679. """
  1680. This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
  1681. It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
  1682. option.
  1683. """
  1684. option = "--make-reports"
  1685. if option not in pytest_opt_registered:
  1686. parser.addoption(
  1687. option,
  1688. action="store",
  1689. default=False,
  1690. help="generate report files. The value of this option is used as a prefix to report names",
  1691. )
  1692. pytest_opt_registered[option] = 1
  1693. def pytest_terminal_summary_main(tr, id):
  1694. """
  1695. Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
  1696. directory. The report files are prefixed with the test suite name.
  1697. This function emulates --duration and -rA pytest arguments.
  1698. This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
  1699. there.
  1700. Args:
  1701. - tr: `terminalreporter` passed from `conftest.py`
  1702. - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
  1703. needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
  1704. NB: this functions taps into a private _pytest API and while unlikely, it could break should pytest do internal
  1705. changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-`
  1706. plugins and interfere.
  1707. """
  1708. from _pytest.config import create_terminal_writer
  1709. if not len(id):
  1710. id = "tests"
  1711. config = tr.config
  1712. orig_writer = config.get_terminal_writer()
  1713. orig_tbstyle = config.option.tbstyle
  1714. orig_reportchars = tr.reportchars
  1715. dir = f"reports/{id}"
  1716. Path(dir).mkdir(parents=True, exist_ok=True)
  1717. report_files = {
  1718. k: f"{dir}/{k}.txt"
  1719. for k in [
  1720. "durations",
  1721. "errors",
  1722. "failures_long",
  1723. "failures_short",
  1724. "failures_line",
  1725. "passes",
  1726. "stats",
  1727. "summary_short",
  1728. "warnings",
  1729. ]
  1730. }
  1731. # custom durations report
  1732. # note: there is no need to call pytest --durations=XX to get this separate report
  1733. # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
  1734. dlist = []
  1735. for replist in tr.stats.values():
  1736. for rep in replist:
  1737. if hasattr(rep, "duration"):
  1738. dlist.append(rep)
  1739. if dlist:
  1740. dlist.sort(key=lambda x: x.duration, reverse=True)
  1741. with open(report_files["durations"], "w") as f:
  1742. durations_min = 0.05 # sec
  1743. f.write("slowest durations\n")
  1744. for i, rep in enumerate(dlist):
  1745. if rep.duration < durations_min:
  1746. f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted")
  1747. break
  1748. f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
  1749. def summary_failures_short(tr):
  1750. # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
  1751. reports = tr.getreports("failed")
  1752. if not reports:
  1753. return
  1754. tr.write_sep("=", "FAILURES SHORT STACK")
  1755. for rep in reports:
  1756. msg = tr._getfailureheadline(rep)
  1757. tr.write_sep("_", msg, red=True, bold=True)
  1758. # chop off the optional leading extra frames, leaving only the last one
  1759. longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.MULTILINE | re.DOTALL)
  1760. tr._tw.line(longrepr)
  1761. # note: not printing out any rep.sections to keep the report short
  1762. # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
  1763. # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
  1764. # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
  1765. # pytest-instafail does that)
  1766. # report failures with line/short/long styles
  1767. config.option.tbstyle = "auto" # full tb
  1768. with open(report_files["failures_long"], "w") as f:
  1769. tr._tw = create_terminal_writer(config, f)
  1770. tr.summary_failures()
  1771. # config.option.tbstyle = "short" # short tb
  1772. with open(report_files["failures_short"], "w") as f:
  1773. tr._tw = create_terminal_writer(config, f)
  1774. summary_failures_short(tr)
  1775. config.option.tbstyle = "line" # one line per error
  1776. with open(report_files["failures_line"], "w") as f:
  1777. tr._tw = create_terminal_writer(config, f)
  1778. tr.summary_failures()
  1779. with open(report_files["errors"], "w") as f:
  1780. tr._tw = create_terminal_writer(config, f)
  1781. tr.summary_errors()
  1782. with open(report_files["warnings"], "w") as f:
  1783. tr._tw = create_terminal_writer(config, f)
  1784. tr.summary_warnings() # normal warnings
  1785. tr.summary_warnings() # final warnings
  1786. tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
  1787. # Skip the `passes` report, as it starts to take more than 5 minutes, and sometimes it timeouts on CircleCI if it
  1788. # takes > 10 minutes (as this part doesn't generate any output on the terminal).
  1789. # (also, it seems there is no useful information in this report, and we rarely need to read it)
  1790. # with open(report_files["passes"], "w") as f:
  1791. # tr._tw = create_terminal_writer(config, f)
  1792. # tr.summary_passes()
  1793. with open(report_files["summary_short"], "w") as f:
  1794. tr._tw = create_terminal_writer(config, f)
  1795. tr.short_test_summary()
  1796. with open(report_files["stats"], "w") as f:
  1797. tr._tw = create_terminal_writer(config, f)
  1798. tr.summary_stats()
  1799. # restore:
  1800. tr._tw = orig_writer
  1801. tr.reportchars = orig_reportchars
  1802. config.option.tbstyle = orig_tbstyle
  1803. # --- distributed testing functions --- #
  1804. # adapted from https://stackoverflow.com/a/59041913/9201239
  1805. import asyncio # noqa
  1806. class _RunOutput:
  1807. def __init__(self, returncode, stdout, stderr):
  1808. self.returncode = returncode
  1809. self.stdout = stdout
  1810. self.stderr = stderr
  1811. async def _read_stream(stream, callback):
  1812. while True:
  1813. line = await stream.readline()
  1814. if line:
  1815. callback(line)
  1816. else:
  1817. break
  1818. async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
  1819. if echo:
  1820. print("\nRunning: ", " ".join(cmd))
  1821. p = await asyncio.create_subprocess_exec(
  1822. cmd[0],
  1823. *cmd[1:],
  1824. stdin=stdin,
  1825. stdout=asyncio.subprocess.PIPE,
  1826. stderr=asyncio.subprocess.PIPE,
  1827. env=env,
  1828. )
  1829. # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
  1830. # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
  1831. #
  1832. # If it starts hanging, will need to switch to the following code. The problem is that no data
  1833. # will be seen until it's done and if it hangs for example there will be no debug info.
  1834. # out, err = await p.communicate()
  1835. # return _RunOutput(p.returncode, out, err)
  1836. out = []
  1837. err = []
  1838. def tee(line, sink, pipe, label=""):
  1839. line = line.decode("utf-8").rstrip()
  1840. sink.append(line)
  1841. if not quiet:
  1842. print(label, line, file=pipe)
  1843. # XXX: the timeout doesn't seem to make any difference here
  1844. await asyncio.wait(
  1845. [
  1846. asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:"))),
  1847. asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))),
  1848. ],
  1849. timeout=timeout,
  1850. )
  1851. return _RunOutput(await p.wait(), out, err)
  1852. def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
  1853. loop = asyncio.get_event_loop()
  1854. result = loop.run_until_complete(
  1855. _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
  1856. )
  1857. cmd_str = " ".join(cmd)
  1858. if result.returncode > 0:
  1859. stderr = "\n".join(result.stderr)
  1860. raise RuntimeError(
  1861. f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
  1862. f"The combined stderr from workers follows:\n{stderr}"
  1863. )
  1864. # check that the subprocess actually did run and produced some output, should the test rely on
  1865. # the remote side to do the testing
  1866. if not result.stdout and not result.stderr:
  1867. raise RuntimeError(f"'{cmd_str}' produced no output.")
  1868. return result
  1869. def pytest_xdist_worker_id():
  1870. """
  1871. Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0
  1872. if `-n 1` or `pytest-xdist` isn't being used.
  1873. """
  1874. worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
  1875. worker = re.sub(r"^gw", "", worker, 0, re.MULTILINE)
  1876. return int(worker)
  1877. def get_torch_dist_unique_port():
  1878. """
  1879. Returns a free port number that can be fed to `torch.distributed.launch`'s `--master_port` argument.
  1880. Binds to port 0 to let the OS assign an available port, avoiding collisions from hardcoded ports
  1881. and TCP TIME_WAIT issues between sequential subprocess launches.
  1882. """
  1883. import socket
  1884. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  1885. s.bind(("", 0))
  1886. return s.getsockname()[1]
  1887. def nested_simplify(obj, decimals=3):
  1888. """
  1889. Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
  1890. within tests.
  1891. """
  1892. import numpy as np
  1893. if isinstance(obj, list):
  1894. return [nested_simplify(item, decimals) for item in obj]
  1895. if isinstance(obj, tuple):
  1896. return tuple(nested_simplify(item, decimals) for item in obj)
  1897. elif isinstance(obj, np.ndarray):
  1898. return nested_simplify(obj.tolist())
  1899. elif isinstance(obj, Mapping):
  1900. return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
  1901. elif isinstance(obj, (str, int, np.int64)) or obj is None:
  1902. return obj
  1903. elif is_torch_available() and isinstance(obj, torch.Tensor):
  1904. return nested_simplify(obj.tolist(), decimals)
  1905. elif isinstance(obj, float):
  1906. return round(obj, decimals)
  1907. elif isinstance(obj, (np.int32, np.float32, np.float16)):
  1908. return nested_simplify(obj.item(), decimals)
  1909. else:
  1910. raise Exception(f"Not supported: {type(obj)}")
  1911. def check_json_file_has_correct_format(file_path):
  1912. with open(file_path) as f:
  1913. lines = f.readlines()
  1914. if len(lines) == 1:
  1915. # length can only be 1 if dict is empty
  1916. assert lines[0] == "{}"
  1917. else:
  1918. # otherwise make sure json has correct format (at least 3 lines)
  1919. assert len(lines) >= 3
  1920. # each key one line, ident should be 2, min length is 3
  1921. assert lines[0].strip() == "{"
  1922. for line in lines[1:-1]:
  1923. left_indent = len(lines[1]) - len(lines[1].lstrip())
  1924. assert left_indent == 2
  1925. assert lines[-1].strip() == "}"
  1926. def to_2tuple(x):
  1927. if isinstance(x, collections.abc.Iterable):
  1928. return x
  1929. return (x, x)
  1930. # These utils relate to ensuring the right error message is received when running scripts
  1931. class SubprocessCallException(Exception):
  1932. pass
  1933. def run_command(command: list[str], return_stdout=False):
  1934. """
  1935. Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
  1936. if an error occurred while running `command`
  1937. """
  1938. try:
  1939. output = subprocess.check_output(command, stderr=subprocess.STDOUT)
  1940. if return_stdout:
  1941. if hasattr(output, "decode"):
  1942. output = output.decode("utf-8")
  1943. return output
  1944. except subprocess.CalledProcessError as e:
  1945. raise SubprocessCallException(
  1946. f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
  1947. ) from e
  1948. class RequestCounter:
  1949. """
  1950. Helper class that will count all requests made online.
  1951. Might not be robust if urllib3 changes its logging format but should be good enough for us.
  1952. Usage:
  1953. ```py
  1954. with RequestCounter() as counter:
  1955. _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
  1956. assert counter["GET"] == 0
  1957. assert counter["HEAD"] == 1
  1958. assert counter.total_calls == 1
  1959. ```
  1960. """
  1961. def __enter__(self):
  1962. self._counter = defaultdict(int)
  1963. self._thread_id = threading.get_ident()
  1964. self._extra_info = []
  1965. def patched_with_thread_info(func):
  1966. def wrap(*args, **kwargs):
  1967. self._extra_info.append(threading.get_ident())
  1968. return func(*args, **kwargs)
  1969. return wrap
  1970. import urllib3
  1971. self.patcher = patch.object(
  1972. urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug)
  1973. )
  1974. self.mock = self.patcher.start()
  1975. return self
  1976. def __exit__(self, *args, **kwargs) -> None:
  1977. assert len(self.mock.call_args_list) == len(self._extra_info)
  1978. for thread_id, call in zip(self._extra_info, self.mock.call_args_list):
  1979. if thread_id != self._thread_id:
  1980. continue
  1981. # code 307: the URL being requested by the user has moved to a temporary location
  1982. if call.args[-2] == 307:
  1983. continue
  1984. log = call.args[0] % call.args[1:]
  1985. for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
  1986. if method in log:
  1987. self._counter[method] += 1
  1988. break
  1989. self.patcher.stop()
  1990. def __getitem__(self, key: str) -> int:
  1991. return self._counter[key]
  1992. @property
  1993. def total_calls(self) -> int:
  1994. return sum(self._counter.values())
  1995. def is_flaky(max_attempts: int = 5, wait_before_retry: float | None = None, description: str | None = None):
  1996. """
  1997. To decorate flaky tests. They will be retried on failures.
  1998. Please note that our push tests use `pytest-rerunfailures`, which prompts the CI to rerun certain types of
  1999. failed tests. More specifically, if the test exception contains any substring in `FLAKY_TEST_FAILURE_PATTERNS`
  2000. (in `.circleci/create_circleci_config.py`), it will be rerun. If you find a recurrent pattern of failures,
  2001. expand `FLAKY_TEST_FAILURE_PATTERNS` in our CI configuration instead of using `is_flaky`.
  2002. Args:
  2003. max_attempts (`int`, *optional*, defaults to 5):
  2004. The maximum number of attempts to retry the flaky test.
  2005. wait_before_retry (`float`, *optional*):
  2006. If provided, will wait that number of seconds before retrying the test.
  2007. description (`str`, *optional*):
  2008. A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors,
  2009. etc.)
  2010. """
  2011. def decorator(test_func_ref):
  2012. @functools.wraps(test_func_ref)
  2013. def wrapper(*args, **kwargs):
  2014. retry_count = 1
  2015. while retry_count < max_attempts:
  2016. try:
  2017. return test_func_ref(*args, **kwargs)
  2018. except Exception as err:
  2019. logger.error(f"Test failed with {err} at try {retry_count}/{max_attempts}.")
  2020. if wait_before_retry is not None:
  2021. time.sleep(wait_before_retry)
  2022. retry_count += 1
  2023. return test_func_ref(*args, **kwargs)
  2024. return unittest.skipUnless(_run_flaky_tests, "test is flaky")(wrapper)
  2025. return decorator
  2026. def hub_retry(max_attempts: int = 5, wait_before_retry: float | None = 2):
  2027. """
  2028. To decorate tests that download from the Hub. They can fail due to a
  2029. variety of network issues such as timeouts, connection resets, etc.
  2030. Args:
  2031. max_attempts (`int`, *optional*, defaults to 5):
  2032. The maximum number of attempts to retry the flaky test.
  2033. wait_before_retry (`float`, *optional*, defaults to 2):
  2034. If provided, will wait that number of seconds before retrying the test.
  2035. """
  2036. def decorator(test_func_ref):
  2037. @functools.wraps(test_func_ref)
  2038. def wrapper(*args, **kwargs):
  2039. retry_count = 1
  2040. while retry_count < max_attempts:
  2041. try:
  2042. return test_func_ref(*args, **kwargs)
  2043. # We catch all exceptions related to network issues from httpx
  2044. except (
  2045. httpx.HTTPError,
  2046. httpx.RequestError,
  2047. httpx.TimeoutException,
  2048. httpx.ReadTimeout,
  2049. httpx.ConnectError,
  2050. httpx.NetworkError,
  2051. ) as err:
  2052. logger.error(
  2053. f"Test failed with {err} at try {retry_count}/{max_attempts} as it couldn't connect to the specified Hub repository."
  2054. )
  2055. if wait_before_retry is not None:
  2056. time.sleep(wait_before_retry)
  2057. retry_count += 1
  2058. return test_func_ref(*args, **kwargs)
  2059. return wrapper
  2060. return decorator
  2061. def run_first(test_case):
  2062. """
  2063. Decorator marking a test with order(1). When pytest-order plugin is installed, tests marked with this decorator
  2064. are guaranteed to run first.
  2065. This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a
  2066. single process at a time. So we make sure all tests that run in a subprocess are launched first, to avoid device
  2067. allocation conflicts.
  2068. """
  2069. # Without this check, we get unwanted warnings when it's not installed
  2070. if is_pytest_order_available():
  2071. import pytest
  2072. return pytest.mark.order(1)(test_case)
  2073. else:
  2074. return test_case
  2075. def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
  2076. """
  2077. To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
  2078. Args:
  2079. test_case (`unittest.TestCase`):
  2080. The test that will run `target_func`.
  2081. target_func (`Callable`):
  2082. The function implementing the actual testing logic.
  2083. inputs (`dict`, *optional*, defaults to `None`):
  2084. The inputs that will be passed to `target_func` through an (input) queue.
  2085. timeout (`int`, *optional*, defaults to `None`):
  2086. The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env.
  2087. variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
  2088. """
  2089. if timeout is None:
  2090. timeout = int(os.environ.get("PYTEST_TIMEOUT", "600"))
  2091. start_methohd = "spawn"
  2092. ctx = multiprocessing.get_context(start_methohd)
  2093. input_queue = ctx.Queue(1)
  2094. output_queue = ctx.JoinableQueue(1)
  2095. # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
  2096. input_queue.put(inputs, timeout=timeout)
  2097. process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
  2098. process.start()
  2099. # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
  2100. # the test to exit properly.
  2101. try:
  2102. results = output_queue.get(timeout=timeout)
  2103. output_queue.task_done()
  2104. except Exception as e:
  2105. process.terminate()
  2106. test_case.fail(e)
  2107. process.join(timeout=timeout)
  2108. if results["error"] is not None:
  2109. test_case.fail(f"{results['error']}")
  2110. def run_test_using_subprocess(func):
  2111. """
  2112. To decorate a test to run in a subprocess using the `subprocess` module. This could avoid potential GPU memory
  2113. issues (GPU OOM or a test that causes many subsequential failing with `CUDA error: device-side assert triggered`).
  2114. """
  2115. import pytest
  2116. @functools.wraps(func)
  2117. def wrapper(*args, **kwargs):
  2118. if os.getenv("_INSIDE_SUB_PROCESS", None) == "1":
  2119. func(*args, **kwargs)
  2120. else:
  2121. test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1])
  2122. try:
  2123. env = copy.deepcopy(os.environ)
  2124. env["_INSIDE_SUB_PROCESS"] = "1"
  2125. # This prevents the entries in `short test summary info` given by the subprocess being truncated. so the
  2126. # full information can be passed to the parent pytest process.
  2127. # See: https://docs.pytest.org/en/stable/explanation/ci.html
  2128. env["CI"] = "true"
  2129. # If not subclass of `unitTest.TestCase` and `pytestconfig` is used: try to grab and use the arguments
  2130. if "pytestconfig" in kwargs:
  2131. command = list(kwargs["pytestconfig"].invocation_params.args)
  2132. for idx, x in enumerate(command):
  2133. if x in kwargs["pytestconfig"].args:
  2134. test = test.split("::")[1:]
  2135. command[idx] = "::".join([f"{func.__globals__['__file__']}"] + test)
  2136. command = [f"{sys.executable}", "-m", "pytest"] + command
  2137. command = [x for x in command if x != "--no-summary"]
  2138. # Otherwise, simply run the test with no option at all
  2139. else:
  2140. command = [f"{sys.executable}", "-m", "pytest", f"{test}"]
  2141. subprocess.run(command, env=env, check=True, capture_output=True)
  2142. except subprocess.CalledProcessError as e:
  2143. exception_message = e.stdout.decode()
  2144. lines = exception_message.split("\n")
  2145. # Add a first line with more informative information instead of just `= test session starts =`.
  2146. # This makes the `short test summary info` section more useful.
  2147. if "= test session starts =" in lines[0]:
  2148. text = ""
  2149. for line in lines[1:]:
  2150. if line.startswith("FAILED "):
  2151. text = line[len("FAILED ") :]
  2152. text = "".join(text.split(" - ")[1:])
  2153. elif line.startswith("=") and line.endswith("=") and " failed in " in line:
  2154. break
  2155. elif len(text) > 0:
  2156. text += f"\n{line}"
  2157. text = "(subprocess) " + text
  2158. lines = [text] + lines
  2159. exception_message = "\n".join(lines)
  2160. raise pytest.fail(exception_message, pytrace=False)
  2161. return wrapper
"""
The following contains utils to run the documentation tests without having to overwrite any files.

The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is
made as a print would otherwise fail the corresponding line.

To skip CUDA tests, make sure to call `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules <path_to_files_to_test>`
"""
  2168. def preprocess_string(string, skip_cuda_tests):
  2169. """Prepare a docstring or a `.md` file to be run by doctest.
  2170. The argument `string` would be the whole file content if it is a `.md` file. For a python file, it would be one of
  2171. its docstring. In each case, it may contain multiple python code examples. If `skip_cuda_tests` is `True` and a
  2172. cuda stuff is detective (with a heuristic), this method will return an empty string so no doctest will be run for
  2173. `string`.
  2174. """
  2175. codeblock_pattern = r"(```(?:python|py)\s*\n\s*>>> )(.*?```)"
  2176. codeblocks = re.split(codeblock_pattern, string, flags=re.DOTALL)
  2177. is_cuda_found = False
  2178. for i, codeblock in enumerate(codeblocks):
  2179. if "load_dataset(" in codeblock and "# doctest: +IGNORE_RESULT" not in codeblock:
  2180. codeblocks[i] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", codeblock)
  2181. if (
  2182. (">>>" in codeblock or "..." in codeblock)
  2183. and re.search(r"cuda|to\(0\)|device=0", codeblock)
  2184. and skip_cuda_tests
  2185. ):
  2186. is_cuda_found = True
  2187. break
  2188. modified_string = ""
  2189. if not is_cuda_found:
  2190. modified_string = "".join(codeblocks)
  2191. return modified_string
class HfDocTestParser(doctest.DocTestParser):
    """
    Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This
    means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also
    added anywhere a `load_dataset` call is made as a print would otherwise fail the corresponding line.

    Tests involving cuda are skipped based on a naive pattern that should be updated if it is not enough.
    """

    # This regular expression is used to find doctest examples in a
    # string. It defines three groups: `source` is the source code
    # (including leading indentation and prompts); `indent` is the
    # indentation of the first (PS1) line of the source code; and
    # `want` is the expected output (including leading indentation).
    # It replaces the stdlib `DocTestParser._EXAMPLE_RE`, which `DocTestParser.parse` looks up on the instance.
    # The pattern text below is kept verbatim (re.VERBOSE: unescaped whitespace outside `[ ]` is ignored).
    # fmt: off
    _EXAMPLE_RE = re.compile(r'''
# Source consists of a PS1 line followed by zero or more PS2 lines.
(?P<source>
(?:^(?P<indent> [ ]*) >>> .*) # PS1 line
(?:\n [ ]* \.\.\. .*)*) # PS2 lines
\n?
# Want consists of any non-blank lines that do not start with PS1.
(?P<want> (?:(?![ ]*$) # Not a blank line
(?![ ]*>>>) # Not a line starting with PS1
# !!!!!!!!!!! HF Specific !!!!!!!!!!!
(?:(?!```).)* # Match any character except '`' until a '```' is found (this is specific to HF because black removes the last line)
# !!!!!!!!!!! HF Specific !!!!!!!!!!!
(?:\n|$) # Match a new line or end of string
)*)
''', re.MULTILINE | re.VERBOSE
    )
    # fmt: on

    # !!!!!!!!!!! HF Specific !!!!!!!!!!!
    # Evaluated once, when the class is defined: `SKIP_CUDA_DOCTEST=1` in the environment turns CUDA skipping on.
    skip_cuda_tests: bool = os.environ.get("SKIP_CUDA_DOCTEST", "0") == "1"
    # !!!!!!!!!!! HF Specific !!!!!!!!!!!

    def parse(self, string, name="<string>"):
        """
        Overwrites the `parse` method to run `preprocess_string` before calling `super().parse`: `load_dataset(` calls
        get a `# doctest: +IGNORE_RESULT` marker and, when `skip_cuda_tests` is set, a string that looks like it uses
        CUDA is replaced by an empty string (so it yields no examples).
        """
        string = preprocess_string(string, self.skip_cuda_tests)
        return super().parse(string, name)
class HfDoctestModule(Module):
    """
    Overwrites the `DoctestModule` of the pytest package to make sure the HFDocTestParser is used when discovering
    tests.

    Apart from the lines marked `HF Specific`, `collect` follows pytest's doctest module collection.
    """

    def collect(self) -> Iterable[DoctestItem]:
        class MockAwareDocTestFinder(doctest.DocTestFinder):
            """A hackish doctest finder that overrides stdlib internals to fix a stdlib bug.

            https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532
            """

            def _find_lineno(self, obj, source_lines):
                """Doctest code does not take into account `@property`, this
                is a hackish way to fix it. https://bugs.python.org/issue17446

                Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be
                reported upstream. #8796
                """
                if isinstance(obj, property):
                    obj = getattr(obj, "fget", obj)

                if hasattr(obj, "__wrapped__"):
                    # Get the main obj in case of it being wrapped
                    obj = inspect.unwrap(obj)

                # Type ignored because this is a private function.
                return super()._find_lineno(  # type:ignore[misc]
                    obj,
                    source_lines,
                )

            def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None:
                # Objects detected as mocks are not searched for doctests.
                # NOTE(review): `_is_mocked` / `_patch_unwrap_mock_aware` are presumably pytest's `_pytest.doctest`
                # helpers imported at the top of the file — confirm.
                if _is_mocked(obj):
                    return
                with _patch_unwrap_mock_aware():
                    # Type ignored because this is a private function.
                    super()._find(  # type:ignore[misc]
                        tests, obj, name, module, source_lines, globs, seen
                    )

        # conftest.py files are imported through the plugin manager; any other file through `import_path`.
        if self.path.name == "conftest.py":
            module = self.config.pluginmanager._importconftest(
                self.path,
                self.config.getoption("importmode"),
                rootpath=self.config.rootpath,
            )
        else:
            try:
                module = import_path(
                    self.path,
                    root=self.config.rootpath,
                    mode=self.config.getoption("importmode"),
                )
            except ImportError:
                # With `doctest_ignore_import_errors` set, an unimportable module is skipped instead of failing.
                if self.config.getvalue("doctest_ignore_import_errors"):
                    skip("unable to import module %r" % self.path)
                else:
                    raise

        # !!!!!!!!!!! HF Specific !!!!!!!!!!!
        finder = MockAwareDocTestFinder(parser=HfDocTestParser())
        # !!!!!!!!!!! HF Specific !!!!!!!!!!!
        optionflags = get_optionflags(self)
        runner = _get_runner(
            verbose=False,
            optionflags=optionflags,
            checker=_get_checker(),
            continue_on_failure=_get_continue_on_failure(self.config),
        )
        for test in finder.find(module, module.__name__):
            # Docstrings without examples yield no item. This includes strings blanked by `preprocess_string` because
            # they look like CUDA code while `SKIP_CUDA_DOCTEST=1`.
            if test.examples:  # skip empty doctests and cuda
                yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)
  2297. def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable], *args, **kwargs):
  2298. if device not in dispatch_table:
  2299. if not callable(dispatch_table["default"]):
  2300. return dispatch_table["default"]
  2301. return dispatch_table["default"](*args, **kwargs)
  2302. fn = dispatch_table[device]
  2303. # Some device agnostic functions return values or None, will return then directly.
  2304. if not callable(fn):
  2305. return fn
  2306. return fn(*args, **kwargs)
# Each `BACKEND_*` table maps a device type (e.g. "cuda") to either a callable, which the matching `backend_*` helper
# calls, or a plain value (`None`, `0`, a module), which it returns as is. The "default" entry is used for device types
# without an entry of their own (see `_device_agnostic_dispatch`).
if is_torch_available():
    # Mappings from device names to callable functions to support device agnostic
    # testing.
    BACKEND_MANUAL_SEED = {
        "cuda": torch.cuda.manual_seed,
        "cpu": torch.manual_seed,
        "default": torch.manual_seed,
    }
    # `None` entries make the operation a no-op that returns `None`.
    BACKEND_EMPTY_CACHE = {
        "cuda": torch.cuda.empty_cache,
        "cpu": None,
        "default": None,
    }
    # NOTE(review): "cpu" reports 0 devices while unknown device types default to 1 — confirm this is intended.
    BACKEND_DEVICE_COUNT = {
        "cuda": torch.cuda.device_count,
        "cpu": lambda: 0,
        "default": lambda: 1,
    }
    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
        "cuda": torch.cuda.reset_max_memory_allocated,
        "cpu": None,
        "default": None,
    }
    # Memory queries report 0 on devices without a dedicated function.
    BACKEND_MAX_MEMORY_ALLOCATED = {
        "cuda": torch.cuda.max_memory_allocated,
        "cpu": 0,
        "default": 0,
    }
    BACKEND_RESET_PEAK_MEMORY_STATS = {
        "cuda": torch.cuda.reset_peak_memory_stats,
        "cpu": None,
        "default": None,
    }
    BACKEND_MEMORY_ALLOCATED = {
        "cuda": torch.cuda.memory_allocated,
        "cpu": 0,
        "default": 0,
    }
    BACKEND_SYNCHRONIZE = {
        "cuda": torch.cuda.synchronize,
        "cpu": None,
        "default": None,
    }
    BACKEND_TORCH_ACCELERATOR_MODULE = {
        "cuda": torch.cuda,
        "cpu": None,
        "default": None,
    }
else:
    # Without torch every lookup falls through to "default": no-ops, 0 devices, 0 bytes, no accelerator module.
    BACKEND_MANUAL_SEED = {"default": None}
    BACKEND_EMPTY_CACHE = {"default": None}
    BACKEND_DEVICE_COUNT = {"default": lambda: 0}
    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None}
    BACKEND_RESET_PEAK_MEMORY_STATS = {"default": None}
    BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0}
    BACKEND_MEMORY_ALLOCATED = {"default": 0}
    BACKEND_SYNCHRONIZE = {"default": None}
    BACKEND_TORCH_ACCELERATOR_MODULE = {"default": None}

# Register the optional accelerator backends that are available. Operations a backend does not register here fall
# back to the table's "default" entry.
if is_torch_hpu_available():
    BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed
    BACKEND_DEVICE_COUNT["hpu"] = torch.hpu.device_count
    BACKEND_TORCH_ACCELERATOR_MODULE["hpu"] = torch.hpu

if is_torch_mlu_available():
    BACKEND_EMPTY_CACHE["mlu"] = torch.mlu.empty_cache
    BACKEND_MANUAL_SEED["mlu"] = torch.mlu.manual_seed
    BACKEND_DEVICE_COUNT["mlu"] = torch.mlu.device_count
    BACKEND_TORCH_ACCELERATOR_MODULE["mlu"] = torch.mlu

if is_torch_npu_available():
    BACKEND_EMPTY_CACHE["npu"] = torch.npu.empty_cache
    BACKEND_MANUAL_SEED["npu"] = torch.npu.manual_seed
    BACKEND_DEVICE_COUNT["npu"] = torch.npu.device_count
    BACKEND_TORCH_ACCELERATOR_MODULE["npu"] = torch.npu

if is_torch_xpu_available():
    BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache
    BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed
    BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count
    # The max-memory reset reuses `reset_peak_memory_stats` on XPU.
    # NOTE(review): presumably because `torch.xpu` lacks `reset_max_memory_allocated` — confirm.
    BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
    BACKEND_RESET_PEAK_MEMORY_STATS["xpu"] = torch.xpu.reset_peak_memory_stats
    BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
    BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated
    BACKEND_SYNCHRONIZE["xpu"] = torch.xpu.synchronize
    BACKEND_TORCH_ACCELERATOR_MODULE["xpu"] = torch.xpu

# NOTE(review): the XLA entries dispatch to `torch.cuda` functions — presumably meant for XLA on CUDA GPUs; confirm.
if is_torch_xla_available():
    BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache
    BACKEND_MANUAL_SEED["xla"] = torch.cuda.manual_seed
    BACKEND_DEVICE_COUNT["xla"] = torch.cuda.device_count
  2393. def backend_manual_seed(device: str, seed: int):
  2394. return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
  2395. def backend_empty_cache(device: str):
  2396. return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
  2397. def backend_device_count(device: str):
  2398. return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
  2399. def backend_reset_max_memory_allocated(device: str):
  2400. return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
  2401. def backend_reset_peak_memory_stats(device: str):
  2402. return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
  2403. def backend_max_memory_allocated(device: str):
  2404. return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
  2405. def backend_memory_allocated(device: str):
  2406. return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED)
  2407. def backend_synchronize(device: str):
  2408. return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
  2409. def backend_torch_accelerator_module(device: str):
  2410. return _device_agnostic_dispatch(device, BACKEND_TORCH_ACCELERATOR_MODULE)
  if is_torch_available():
      # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
      # into device to function mappings.
      if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ:
          device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"]
          if not Path(device_spec_path).is_file():
              raise ValueError(
                  f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}'"
              )

          # Import the spec file as a top-level module named after its file stem, after making its directory
          # importable. Using the basename (instead of slicing the raw path up to ".py") also works when the path
          # contains directories, and checking the extension verifies we are importing a python file.
          device_spec_dir, device_spec_file = os.path.split(os.path.realpath(device_spec_path))
          import_name, device_spec_ext = os.path.splitext(device_spec_file)
          if device_spec_ext != ".py":
              raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}'")
          sys.path.append(device_spec_dir)
          device_spec_module = importlib.import_module(import_name)

          # Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early.
          try:
              device_name = device_spec_module.DEVICE_NAME
          except AttributeError as e:
              raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e

          if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name:
              msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
              msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name."
              raise ValueError(msg)

          torch_device = device_name

          def update_mapping_from_spec(device_fn_dict: dict[str, Callable], attribute_name: str):
              """
              Register `device_spec_module.<attribute_name>` as the `torch_device` entry of `device_fn_dict`.

              Raises:
                  AttributeError: if the spec module lacks `attribute_name` and `device_fn_dict` has no "default" entry.
              """
              try:
                  # Try to import the function directly
                  spec_fn = getattr(device_spec_module, attribute_name)
                  device_fn_dict[torch_device] = spec_fn
              except AttributeError as e:
                  # If the function doesn't exist, and there is no default, throw an error
                  if "default" not in device_fn_dict:
                      raise AttributeError(
                          f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
                      ) from e

          # Add one entry here for each `BACKEND_*` dictionary.
          update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
          update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
          update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
  def compare_pipeline_output_to_hub_spec(output, hub_spec):
      """
      Check that the keys of a pipeline `output` agree with the dataclass `hub_spec`.

      A field of `hub_spec` whose default is `dataclasses.MISSING` is required and must be a key of `output`. Every key
      of `output` must be the name of some field of `hub_spec`.

      Raises:
          KeyError: with a message listing the matching, missing and unexpected keys when the check fails.
      """
      spec_fields = fields(hub_spec)
      spec_names = {spec_field.name for spec_field in spec_fields}

      # Required fields absent from the output, in declaration order.
      missing = [f.name for f in spec_fields if f.default is MISSING and f.name not in output]
      # Output keys that the spec does not know about, in output order.
      unexpected = [key for key in output if key not in spec_names]

      if not missing and not unexpected:
          return

      matching = sorted(key for key in output if key in spec_names)
      lines = ["Pipeline output does not match Hub spec!"]
      if matching:
          lines.append(f"Matching keys: {matching}")
      if missing:
          lines.append(f"Missing required keys in pipeline output: {missing}")
      if unexpected:
          lines.append(f"Keys in pipeline output that are not in Hub spec: {unexpected}")
      raise KeyError("\n".join(lines))
  @require_torch
  def cleanup(device: str, gc_collect=False):
      """
      Release cached resources between tests.

      Args:
          device (`str`): device name forwarded to `backend_empty_cache`.
          gc_collect (`bool`, *optional*, defaults to `False`): run `gc.collect()` first, before emptying the cache.

      `torch.compiler.reset()` is always called last.
      """
      if gc_collect:
          gc.collect()

      backend_empty_cache(device)
      torch.compiler.reset()
  2482. # Type definition of key used in `Expectations` class.
  2483. DeviceProperties = tuple[str | None, int | None, int | None]
  2484. # Helper type. Makes creating instances of `Expectations` smoother.
  2485. PackedDeviceProperties = tuple[str | None, None | int | tuple[int, int]]
  @cache
  def get_device_properties() -> DeviceProperties:
      """
      Get environment device properties as a `(device_type, major, minor)` tuple (computed once, then cached).

      - CUDA / ROCm: `("cuda" | "rocm", major, minor)` from `torch.cuda.get_device_capability()`.
      - XPU: `("xpu", generation, None)`, the generation being bits 32-39 of the reported architecture value.
      - NPU: `("npu", None, None)`.
      - Anything else: `(torch_device, None, None)`.
      """
      if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
          import torch

          major, minor = torch.cuda.get_device_capability()
          device_type = "rocm" if IS_ROCM_SYSTEM else "cuda"
          return (device_type, major, minor)

      if IS_XPU_SYSTEM:
          import torch

          # To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def
          architecture = torch.xpu.get_device_capability()["architecture"]
          generation = (architecture & 0x000000FF00000000) >> 32
          return ("xpu", generation, None)

      if IS_NPU_SYSTEM:
          return ("npu", None, None)

      return (torch_device, None, None)
  2509. def unpack_device_properties(
  2510. properties: PackedDeviceProperties | None = None,
  2511. ) -> DeviceProperties:
  2512. """
  2513. Unpack a `PackedDeviceProperties` tuple into consistently formatted `DeviceProperties` tuple. If properties is None, it is fetched.
  2514. """
  2515. if properties is None:
  2516. return get_device_properties()
  2517. device_type, major_minor = properties
  2518. if major_minor is None:
  2519. major, minor = None, None
  2520. elif isinstance(major_minor, int):
  2521. major, minor = major_minor, None
  2522. else:
  2523. major, minor = major_minor
  2524. return device_type, major, minor
  class Expectations(UserDict[PackedDeviceProperties, Any]):
      """
      Dictionary of expected test values keyed by packed device properties.

      Keys are `PackedDeviceProperties` tuples such as `("cuda", (8, 6))`, `("xpu", 3)` or the default key `(None, None)`.
      `get_expectation` returns the value whose key best matches the current environment.
      """

      def get_expectation(self) -> Any:
          """
          Find best matching expectation based on environment device properties. We look at device_type and at the
          major and minor version numbers reported by `get_device_properties` (e.g. the CUDA device capability).
          Expectations are stored as a dictionary with keys of the form (device_type, (major, minor)). If the major and
          minor versions are not provided, we use None.
          """
          return self.find_expectation(get_device_properties())

      def unpacked(self) -> list[tuple[DeviceProperties, Any]]:
          """Return `(key, value)` pairs with every key normalized to a `(device_type, major, minor)` tuple."""
          return [(unpack_device_properties(k), v) for k, v in self.data.items()]

      @staticmethod
      def is_default(expectation_key: PackedDeviceProperties) -> bool:
          """
          This function returns True if the expectation_key is the Default expectation (None, None).
          When an Expectation dict contains a Default value, it is generally because the test existed before Expectations.
          When we modify a test to use Expectations for a specific hardware, we don't want to affect the tests on other
          hardwares. Thus we set the previous value as the Default expectation with key (None, None) and add a value for
          the specific hardware with key (hardware_type, (major, minor)).
          """
          return all(p is None for p in expectation_key)

      @staticmethod
      def score(properties: DeviceProperties, other: DeviceProperties) -> float:
          """
          Returns score indicating how similar two `DeviceProperties` tuples are.
          Rules are as follows:
              * Matching `type` adds one point, semi-matching `type` gives 0.1 point (e.g. cuda and rocm).
              * If types match, matching `major` adds another point, and then matching `minor` adds another.
              * The Default expectation (None, None) is worth 0.5 point, which is better than semi-matching. More on this
              in the `is_default` function.
          """
          device_type, major, minor = properties
          other_device_type, other_major, other_minor = other

          score = 0
          # Matching device type, maybe major and minor
          if device_type is not None and device_type == other_device_type:
              score += 1
              if major is not None and major == other_major:
                  score += 1
                  if minor is not None and minor == other_minor:
                      score += 1
          # Semi-matching device type, which carries less importance than the default expectation
          elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
              score = 0.1

          # Default expectation
          if Expectations.is_default(other):
              score = 0.5

          return score

      def find_expectation(self, properties: DeviceProperties = (None, None, None)) -> Any:
          """
          Find best matching expectation based on provided device properties. We score each expectation, and to
          distinguish between expectations with the same score, we use the major and minor version numbers, prioritizing
          most recent versions.

          Raises:
              ValueError: if there are no expectations at all, or if the best one has a score of 0 (no match).
          """
          candidates = self.unpacked()
          # `max()` on an empty sequence would raise an unrelated-looking error.
          if not candidates:
              raise ValueError(f"No matching expectation found for {properties}: no expectations were provided")

          def ranking(item):
              key, _value = item
              _device_type, major, minor = key
              return (
                  Expectations.score(properties, key),
                  major if major is not None else -1,  # -1 if major is None
                  minor if minor is not None else -1,  # -1 if minor is None
              )

          result_key, result = max(candidates, key=ranking)

          if Expectations.score(properties, result_key) == 0:
              raise ValueError(f"No matching expectation found for {properties}")

          return result

      def __repr__(self):
          return f"{self.data}"
  def patch_torch_compile_force_graph():
      """
      Patch `torch.compile` so that every call uses `fullgraph=True`.

      Enabled only when the environment variable `TORCH_COMPILE_FORCE_FULLGRAPH` is one of "yes", "true", "on", "t",
      "y" or "1" (case-insensitive); otherwise this is a no-op. This lets `torch.compile` tests written with
      `fullgraph=False` occasionally run with `fullgraph=True` (without introducing new tests) to check for graph breaks.
      After PR #40137, `CompileConfig.fullgraph` is `False` by default, so this patch is necessary.
      """
      flag = os.environ.get("TORCH_COMPILE_FORCE_FULLGRAPH", "").lower()
      if flag not in ("yes", "true", "on", "t", "y", "1"):
          return

      import torch

      orig_method = torch.compile

      def patched(*args, **kwargs):
          # In `torch.compile`, every argument except `model` is keyword-only, so overriding the keyword is enough.
          kwargs["fullgraph"] = True
          return orig_method(*args, **kwargs)

      torch.compile = patched
  def _get_test_info():
      """
      Collect some information about the current test and about the call, inside it, to a patched function/method.

      For example: the test full name, line number, frames, a traceback starting at the test method, and the source
      code of both the test statement and the call site of the patched method.

      This must be called from the `patched` wrapper created by `_patch_with_call_info`: that frame is located by its
      function name `patched`, and its `orig_method` local is read. `PYTEST_CURRENT_TEST` must have the form
      `file::Class::test_name ...`.

      Returns:
          `tuple`: `(full_test_name, test_file, test_lineno, test_obj, test_method, test_frame, test_traceback,
          test_code_context, caller_path, caller_lineno, caller_code_context, test_info)`.

      Raises:
          RuntimeError: if the test method frame, the `patched` frame or the call site cannot be found in the stack.
      """
      full_test_name = os.environ.get("PYTEST_CURRENT_TEST", "").split(" ")[0]
      test_file, _test_class, test_name = full_test_name.split("::")

      def is_test_method_frame(frame_info):
          # Heuristic: the frame runs a function whose name prefixes the test name, on a `self` that has the test method.
          frame_locals = frame_info.frame.f_locals
          return (
              test_name.startswith(frame_info.function)
              and "self" in frame_locals
              and hasattr(frame_locals["self"], test_name)
          )

      # `inspect.stack()` goes from the most recent frame to the outermost one; the loops below walk it reversed.
      stack_from_inspect = inspect.stack()

      actual_test_file = test_file
      test_frame, test_obj, test_method = None, None, None
      for frame in reversed(stack_from_inspect):
          if is_test_method_frame(frame):
              test_frame = frame
              # The test instance
              test_obj = frame.frame.f_locals["self"]
              # TODO: Do we get the (relative?) path or it's just a file name?
              # TODO: Does `test_obj` always have `tearDown` object?
              actual_test_file = frame.filename
              # TODO: check `test_method` will work used at the several places!
              test_method = getattr(test_obj, test_name)
              break

      if test_frame is None:
          raise RuntimeError(f"Could not find the frame of the test method `{test_name}` in the call stack.")
      test_lineno = test_frame.lineno

      # From the most outer (i.e. python's `runpy.py`) frame towards this function, capture the frames between the test
      # method (included) and the `patched` wrapper (excluded), and remember the `patched` frame: it is used to get the
      # original method being patched, in order to get the context.
      frame_of_patched_obj = None
      captured_frames = []
      to_capture = False
      for frame in reversed(stack_from_inspect):
          if is_test_method_frame(frame):
              to_capture = True
          # TODO: check simply with the name is not robust.
          elif frame.frame.f_code.co_name == "patched":
              frame_of_patched_obj = frame
              to_capture = False
              break
          if to_capture:
              captured_frames.append(frame)

      if frame_of_patched_obj is None:
          raise RuntimeError("`_get_test_info` must be called from within a `patched` function.")

      # Chain the captured frames into a traceback whose first entry is the test method frame (`None` if none captured).
      test_traceback = None
      for frame_info in reversed(captured_frames):
          test_traceback = types.TracebackType(
              test_traceback, frame_info.frame, frame_info.frame.f_lasti, frame_info.frame.f_lineno
          )

      origin_method_being_patched = frame_of_patched_obj.frame.f_locals["orig_method"]
      method_name = origin_method_being_patched.__name__

      # The call site of the original method: the most recent (innermost) frame whose source line mentions its name.
      caller_frame = None
      for frame in reversed(traceback.extract_stack()):
          # `line` is empty/None when the source code is not available.
          if frame.line and method_name in frame.line:
              caller_frame = frame
              break
      if caller_frame is None:
          raise RuntimeError(f"Could not find the call site of `{method_name}` in the call stack.")

      caller_path = os.path.relpath(caller_frame.filename)
      caller_lineno = caller_frame.lineno

      # Get the code context in the test function/method and in the caller (to the patched function/method).
      from _pytest._code.source import Source

      with open(actual_test_file, encoding="utf-8") as fp:
          test_source = Source(fp.read())
      test_code_context = "\n".join(test_source.getstatement(test_lineno - 1).lines)

      with open(caller_path, encoding="utf-8") as fp:
          caller_source = Source(fp.read())
      caller_code_context = "\n".join(caller_source.getstatement(caller_lineno - 1).lines)

      test_info = f"test:\n\n{full_test_name}\n\n{'-' * 80}\n\ntest context: {actual_test_file}:{test_lineno}\n\n{test_code_context}"
      test_info = f"{test_info}\n\n{'-' * 80}\n\ncaller context: {caller_path}:{caller_lineno}\n\n{caller_code_context}"

      return (
          full_test_name,
          test_file,
          test_lineno,
          test_obj,
          test_method,
          test_frame,
          test_traceback,
          test_code_context,
          caller_path,
          caller_lineno,
          caller_code_context,
          test_info,
      )
  2706. def _get_call_arguments(code_context):
  2707. """
  2708. Analyze the positional and keyword arguments in a call expression.
  2709. This will extract the expressions of the positional and kwyword arguments, and associate them to the positions and
  2710. the keyword argument names.
  2711. """
  2712. def get_argument_name(node):
  2713. """Extract the name/expression from an AST node"""
  2714. if isinstance(node, ast.Name):
  2715. return node.id
  2716. elif isinstance(node, ast.Attribute):
  2717. return ast.unparse(node)
  2718. elif isinstance(node, ast.Constant):
  2719. return repr(node.value)
  2720. else:
  2721. return ast.unparse(node)
  2722. indent = len(code_context) - len(code_context.lstrip())
  2723. code_context = code_context.replace(" " * indent, "")
  2724. try:
  2725. # Parse the line
  2726. tree = ast.parse(code_context, mode="eval")
  2727. assert isinstance(tree.body, ast.Call)
  2728. call_node = tree.body
  2729. if call_node:
  2730. result = {
  2731. "positional_args": [],
  2732. "keyword_args": {},
  2733. "starargs": None, # *args
  2734. "kwargs": None, # **kwargs
  2735. }
  2736. # Extract positional arguments
  2737. for arg in call_node.args:
  2738. arg_name = get_argument_name(arg)
  2739. result["positional_args"].append(arg_name)
  2740. # Extract keyword arguments
  2741. for keyword in call_node.keywords:
  2742. if keyword.arg is None:
  2743. # This is **kwargs
  2744. result["kwargs"] = get_argument_name(keyword.value)
  2745. else:
  2746. # Regular keyword argument
  2747. arg_name = get_argument_name(keyword.value)
  2748. result["keyword_args"][keyword.arg] = arg_name
  2749. return result
  2750. except (SyntaxError, AttributeError) as e:
  2751. print(f"Error parsing: {e}")
  2752. return None
  2753. def _prepare_debugging_info(test_info, info):
  2754. """Combine the information about the test and the call information to a patched function/method within it."""
  2755. info = f"{test_info}\n\n{info}"
  2756. p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt")
  2757. # TODO (ydshieh): This is not safe when we use pytest-xdist with more than 1 worker.
  2758. with open(p, "a") as fp:
  2759. fp.write(f"{info}\n\n{'=' * 120}\n\n")
  2760. return info
  2761. def _patched_tearDown(self, *args, **kwargs):
  2762. """Used to report a test that has failures captured and handled by patched functions/methods (without re-raise).
  2763. The patched functions/methods refer to the `patched` defined in `_patch_with_call_info`, which is applied to
  2764. `torch.testing.assert_close` and `unittest.case.TestCase.assertEqual`.
  2765. The objective is to avoid a failure being silence after being processed.
  2766. If there is any failure that is not handled by the patched functions/methods, we add custom error message for them
  2767. along with the usual pytest failure report.
  2768. """
  2769. # Check for regular failures before clearing:
  2770. # when `_patched_tearDown` is called, the current test fails due to an assertion error given by a method being
  2771. # patched by `_patch_with_call_info`. The patched method catches such an error and continue running the remaining
  2772. # statements within the test. If the test fails with another error not handled by the patched methods, we don't let
  2773. # pytest to fail and report it but the original failure (the first one that was processed) instead.
  2774. # We still record those failures not handled by the patched methods, and add custom messages along with the usual
  2775. # pytest failure report.
  2776. regular_failures_info = []
  2777. errors = None
  2778. if hasattr(self._outcome, "errors"):
  2779. errors = self._outcome.errors
  2780. elif hasattr(self._outcome, "result") and hasattr(self._outcome.result, "errors"):
  2781. errors = self._outcome.result.errors
  2782. if hasattr(self, "_outcome") and errors:
  2783. for error_entry in errors:
  2784. test_instance, (exc_type, exc_obj, exc_tb) = error_entry
  2785. # breakpoint()
  2786. regular_failures_info.append(
  2787. {
  2788. "message": f"{str(exc_obj)}\n\n",
  2789. "type": exc_type.__name__,
  2790. "file": "test_modeling_vit.py",
  2791. "line": 237, # get_deepest_frame_line(exc_tb) # Your helper function
  2792. }
  2793. )
  2794. # Clear the regular failure (i.e. that is not from any of our patched assertion methods) from pytest's records.
  2795. if hasattr(self._outcome, "errors"):
  2796. self._outcome.errors.clear()
  2797. elif hasattr(self._outcome, "result") and hasattr(self._outcome.result, "errors"):
  2798. self._outcome.result.errors.clear()
  2799. # reset back to the original tearDown method, so `_patched_tearDown` won't be run by the subsequent tests if they
  2800. # have only test failures that are not handle by the patched methods (or no test failure at all).
  2801. orig_tearDown = _patched_tearDown.orig_tearDown
  2802. type(self).tearDown = orig_tearDown
  2803. # Call the original tearDown
  2804. orig_tearDown(self, *args, **kwargs)
  2805. # Get the failure
  2806. test_method = getattr(self, self._testMethodName)
  2807. captured_failures = test_method.__func__.captured_failures[id(test_method)]
  2808. # TODO: How could we show several exceptions in a sinigle test on the terminal? (Maybe not a good idea)
  2809. captured_exceptions = captured_failures[0]["exception"]
  2810. captured_traceback = captured_failures[0]["traceback"]
  2811. # Show the captured information on the terminal.
  2812. capturued_info = [x["info"] for x in captured_failures]
  2813. capturued_info_str = f"\n\n{'=' * 80}\n\n".join(capturued_info)
  2814. # Enhance the exception message if there were suppressed failures
  2815. if regular_failures_info:
  2816. enhanced_message = f"""{str(captured_exceptions)}
  2817. {"=" * 80}
  2818. Handled Failures: ({len(capturued_info)} handled):
  2819. {"-" * 80}\n
  2820. {capturued_info_str}
  2821. {"=" * 80}
  2822. Unhandled Failures: ({len(regular_failures_info)} unhandled):
  2823. {"-" * 80}\n
  2824. {", ".join(f"{info['type']}: {info['message']}{info['file']}:{info['line']}" for info in regular_failures_info)}
  2825. {"-" * 80}
  2826. Note: This failure occurred after other failures analyzed by the patched assertion methods.
  2827. To see the full details, temporarily disable assertion patching.
  2828. {"=" * 80}"""
  2829. # Create new exception with enhanced message
  2830. enhanced_exception = type(captured_exceptions)(enhanced_message)
  2831. enhanced_exception.__cause__ = captured_exceptions.__cause__
  2832. enhanced_exception.__context__ = captured_exceptions.__context__
  2833. # Raise with your existing traceback reconstruction
  2834. captured_exceptions = enhanced_exception
  2835. # clean up the recorded status
  2836. del test_method.__func__.captured_failures
  2837. raise captured_exceptions.with_traceback(captured_traceback)
  2838. def _patch_with_call_info(module_or_class, attr_name, _parse_call_info_func, target_args):
  2839. """
  2840. Patch a callerable `attr_name` of a module or class `module_or_class`.
  2841. This will allow us to collect the call information, e.g. the argument names and values, also the literal expressions
  2842. passed as the arguments.
  2843. """
  2844. orig_method = getattr(module_or_class, attr_name)
  2845. if not callable(orig_method):
  2846. return
  2847. def patched(*args, **kwargs):
  2848. # If the target callable is not called within a test, simply call it without modification.
  2849. if not os.environ.get("PYTEST_CURRENT_TEST", ""):
  2850. return orig_method(*args, **kwargs)
  2851. try:
  2852. orig_method(*args, **kwargs)
  2853. except AssertionError as e:
  2854. captured_exception = e
  2855. # captured_traceback = e.__traceback__
  2856. (
  2857. full_test_name,
  2858. test_file,
  2859. test_lineno,
  2860. test_obj,
  2861. test_method,
  2862. test_frame,
  2863. test_traceback,
  2864. test_code_context,
  2865. caller_path,
  2866. caller_lineno,
  2867. caller_code_context,
  2868. test_info,
  2869. ) = _get_test_info()
  2870. test_info = f"{test_info}\n\n{'-' * 80}\n\npatched method: {orig_method.__module__}.{orig_method.__name__}"
  2871. call_argument_expressions = _get_call_arguments(caller_code_context)
  2872. # This is specific
  2873. info = _parse_call_info_func(orig_method, args, kwargs, call_argument_expressions, target_args)
  2874. info = _prepare_debugging_info(test_info, info)
  2875. # If the test is running in a CI environment (e.g. not a manual run), let's raise and fail the test, so it
  2876. # behaves as usual.
  2877. # On Github Actions or CircleCI, this is set automatically.
  2878. # When running manually, it's the user to determine if to set it.
  2879. # This is to avoid the patched function being called `with self.assertRaises(AssertionError):` and fails
  2880. # because of the missing expected `AssertionError`.
  2881. # TODO (ydshieh): If there is way to raise only when we are inside such context managers?
  2882. # TODO (ydshieh): How not to record the failure if it happens inside `self.assertRaises(AssertionError)`?
  2883. if os.getenv("CI") == "true":
  2884. raise captured_exception.with_traceback(test_traceback)
  2885. # Save this, so we can raise at the end of the current test
  2886. captured_failure = {
  2887. "result": "failed",
  2888. "exception": captured_exception,
  2889. "traceback": test_traceback,
  2890. "info": info,
  2891. }
  2892. # Record the failure status and its information, so we can raise it later.
  2893. # We are modifying the (unbound) function at class level: not its logic but only adding a new extra
  2894. # attribute.
  2895. if getattr(test_method.__func__, "captured_failures", None) is None:
  2896. test_method.__func__.captured_failures = {}
  2897. if id(test_method) not in test_method.__func__.captured_failures:
  2898. test_method.__func__.captured_failures[id(test_method)] = []
  2899. test_method.__func__.captured_failures[id(test_method)].append(captured_failure)
  2900. # This modifies the `tearDown` which will be called after every tests, but we reset it back inside
  2901. # `_patched_tearDown`.
  2902. if not hasattr(type(test_obj).tearDown, "orig_tearDown"):
  2903. orig_tearDown = type(test_obj).tearDown
  2904. _patched_tearDown.orig_tearDown = orig_tearDown
  2905. type(test_obj).tearDown = _patched_tearDown
  2906. setattr(module_or_class, attr_name, patched)
  2907. def _parse_call_info(func, args, kwargs, call_argument_expressions, target_args):
  2908. """
  2909. Prepare a string containing the call info to `func`, e.g. argument names/values/expressions.
  2910. """
  2911. signature = inspect.signature(func)
  2912. signature_names = [param.name for param_name, param in signature.parameters.items()]
  2913. # called as `self.method_name()` or `xxx.method_name()`.
  2914. if len(args) == len(call_argument_expressions["positional_args"]) + 1:
  2915. # We simply add "self" as the expression despite it might not be the actual argument name.
  2916. # (This part is very unlikely what a user would be interest to know)
  2917. call_argument_expressions["positional_args"] = ["self"] + call_argument_expressions["positional_args"]
  2918. param_position_mapping = {param_name: idx for idx, param_name in enumerate(signature_names)}
  2919. arg_info = {}
  2920. for arg_name in target_args:
  2921. if arg_name in kwargs:
  2922. arg_value = kwargs[arg_name]
  2923. arg_expr = call_argument_expressions["keyword_args"][arg_name]
  2924. else:
  2925. arg_pos = param_position_mapping[arg_name]
  2926. arg_value = args[arg_pos]
  2927. arg_expr = call_argument_expressions["positional_args"][arg_pos]
  2928. arg_value_str = _format_py_obj(arg_value)
  2929. arg_info[arg_name] = {"arg_expr": arg_expr, "arg_value_str": arg_value_str}
  2930. info = ""
  2931. for arg_name in arg_info:
  2932. arg_expr, arg_value_str = arg_info[arg_name]["arg_expr"], arg_info[arg_name]["arg_value_str"]
  2933. info += f"{'-' * 80}\n\nargument name: `{arg_name}`\nargument expression: `{arg_expr}`\n\nargument value:\n\n{arg_value_str}\n\n"
  2934. # remove the trailing \n\n
  2935. info = info[:-2]
  2936. return info
  def patch_testing_methods_to_collect_info():
      """
      Patch some methods (`torch.testing.assert_close`, `unittest.case.TestCase.assertEqual`, etc).

      This will allow us to collect the call information, e.g. the argument names and values, also the literal expressions
      passed as the arguments. Any existing `captured_info.txt` in `_PATCHED_TESTING_METHODS_OUTPUT_DIR` (current
      directory when unset) is deleted first.
      """
      p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt")
      Path(p).unlink(missing_ok=True)

      if is_torch_available():
          import torch

          _patch_with_call_info(torch.testing, "assert_close", _parse_call_info, target_args=("actual", "expected"))

      # `unittest.TestCase` assertion methods and the names of their two compared parameters.
      unittest_targets = {
          "assertEqual": ("first", "second"),
          "assertListEqual": ("list1", "list2"),
          "assertTupleEqual": ("tuple1", "tuple2"),
          # The second parameter is `set2` (it was mistakenly listed as `set1`, reporting the first set twice).
          "assertSetEqual": ("set1", "set2"),
          "assertDictEqual": ("d1", "d2"),
          "assertIn": ("member", "container"),
          "assertNotIn": ("member", "container"),
          "assertLess": ("a", "b"),
          "assertLessEqual": ("a", "b"),
          "assertGreater": ("a", "b"),
          "assertGreaterEqual": ("a", "b"),
      }
      for method_name, target_args in unittest_targets.items():
          _patch_with_call_info(unittest.case.TestCase, method_name, _parse_call_info, target_args=target_args)
  2961. def torchrun(script: str, nproc_per_node: int, is_torchrun: bool = True, env: dict | None = None):
  2962. """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
  2963. with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
  2964. tmp.write(script)
  2965. tmp.flush()
  2966. tmp.seek(0)
  2967. if is_torchrun:
  2968. cmd = (
  2969. f"torchrun --nproc_per_node {nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
  2970. ).split()
  2971. else:
  2972. cmd = ["python3", tmp.name]
  2973. # Note that the subprocess will be waited for here, and raise an error if not successful
  2974. try:
  2975. _ = subprocess.run(cmd, capture_output=True, env=env, text=True, check=True)
  2976. except subprocess.CalledProcessError as e:
  2977. raise Exception(f"The following error was captured: {e.stderr}")
  2978. def _format_tensor(t, indent_level=0, sci_mode=None):
  2979. """Format torch's tensor in a pretty way to be shown 👀 in the test report."""
  2980. # `torch.testing.assert_close` could accept python int/float numbers.
  2981. if not isinstance(t, torch.Tensor):
  2982. t = torch.tensor(t)
  2983. # Simply make the processing below simpler (not to handle both cases)
  2984. is_scalar = False
  2985. if t.ndim == 0:
  2986. t = torch.tensor([t])
  2987. is_scalar = True
  2988. # For scalar or one-dimensional tensor, keep it as one-line. If there is only one element along any dimension except
  2989. # the last one, we also keep it as one-line.
  2990. if t.ndim <= 1 or set(t.shape[0:-1]) == {1}:
  2991. # Use `detach` to remove `grad_fn=<...>`, and use `to("cpu")` to remove `device='...'`
  2992. t = t.detach().to("cpu")
  2993. # We work directly with the string representation instead the tensor itself
  2994. t_str = str(t)
  2995. # remove `tensor( ... )` so keep only the content
  2996. t_str = t_str.replace("tensor(", "").replace(")", "")
  2997. # Sometimes there are extra spaces between `[` and the first digit of the first value (for alignment).
  2998. # For example `[[ 0.06, -0.51], [-0.76, -0.49]]`. It may have multiple consecutive spaces.
  2999. # Let's remove such extra spaces.
  3000. while "[ " in t_str:
  3001. t_str = t_str.replace("[ ", "[")
  3002. # Put everything in a single line. We replace `\n` by a space ` ` so we still keep `,\n` as `, `.
  3003. t_str = t_str.replace("\n", " ")
  3004. # Remove repeated spaces (introduced by the previous step)
  3005. while " " in t_str:
  3006. t_str = t_str.replace(" ", " ")
  3007. # remove leading `[` and `]` for scalar tensor
  3008. if is_scalar:
  3009. t_str = t_str[1:-1]
  3010. t_str = " " * 4 * indent_level + t_str
  3011. return t_str
  3012. # Otherwise, we separate the representations of each element along an outer dimension by new lines (after a `,`).
  3013. # The representation of each element is obtained by calling this function recursively with current `indent_level`.
  3014. else:
  3015. t_str = str(t)
  3016. # (For the recursive calls should receive this value)
  3017. if sci_mode is None:
  3018. sci_mode = "e+" in t_str or "e-" in t_str
  3019. # Use the original content to determine the scientific mode to use. This is required as the representation of
  3020. # t[index] (computed below) maybe have different format regarding scientific notation.
  3021. torch.set_printoptions(sci_mode=sci_mode)
  3022. t_str = " " * 4 * indent_level + "[\n"
  3023. # Keep the ending `,` for all outer dimensions whose representations are not put in one-line, even if there is
  3024. # only one element along that dimension.
  3025. t_str += ",\n".join(_format_tensor(x, indent_level=indent_level + 1, sci_mode=sci_mode) for x in t)
  3026. t_str += ",\n" + " " * 4 * indent_level + "]"
  3027. torch.set_printoptions(sci_mode=None)
  3028. return t_str
  3029. def _quote_string(s):
  3030. """Given a string `s`, return a python literal expression that give `s` when it is used in a python source code.
  3031. For example, if `s` is the string `abc`, the return value is `"abc"`.
  3032. We choice double quotes over single quote despite `str(s)` would give `'abc'` instead of `"abc"`.
  3033. """
  3034. has_single_quote = "'" in s
  3035. has_double_quote = '"' in s
  3036. if has_single_quote and has_double_quote:
  3037. # replace any double quote by the raw string r'\"'.
  3038. s = s.replace('"', r"\"")
  3039. return f'"{s}"'
  3040. elif has_single_quote:
  3041. return f'"{s}"'
  3042. elif has_double_quote:
  3043. return f"'{s}'"
  3044. else:
  3045. return f'"{s}"'
  3046. def _format_py_obj(obj, indent=0, mode="", cache=None, prefix=""):
  3047. """Format python objects of basic built-in type in a pretty way so we could copy-past them to code editor easily.
  3048. Currently, this support int, float, str, list, tuple, and dict.
  3049. It also works with `torch.Tensor` via calling `format_tesnor`.
  3050. """
  3051. if cache is None:
  3052. cache = {}
  3053. else:
  3054. if (id(obj), indent, mode, prefix) in cache:
  3055. return cache[(id(obj), indent, mode, prefix)]
  3056. # special format method for `torch.Tensor`
  3057. if str(obj.__class__) == "<class 'torch.Tensor'>":
  3058. return _format_tensor(obj)
  3059. elif obj.__class__.__name__ == "str":
  3060. quoted_string = _quote_string(obj)
  3061. # we don't want the newline being interpreted
  3062. quoted_string = quoted_string.replace("\n", r"\n")
  3063. output = quoted_string
  3064. elif obj.__class__.__name__ in ["int", "float"]:
  3065. # for float like `1/3`, we will get `0.3333333333333333`
  3066. output = str(obj)
  3067. elif obj.__class__.__name__ in ["list", "tuple", "dict"]:
  3068. parenthesis = {
  3069. "list": "[]",
  3070. "tuple": "()",
  3071. "dict": "{}",
  3072. }
  3073. p1, p2 = parenthesis[obj.__class__.__name__]
  3074. elements_without_indent = []
  3075. if isinstance(obj, dict):
  3076. for idx, (k, v) in enumerate(obj.items()):
  3077. last_element = idx == len(obj) - 1
  3078. ok = _format_py_obj(k, indent=indent + 1, mode="one-line", cache=cache)
  3079. ov = _format_py_obj(
  3080. v,
  3081. indent=indent + 1,
  3082. mode=mode,
  3083. cache=cache,
  3084. prefix=ok.lstrip() + ": " + "," if not last_element else "",
  3085. )
  3086. # Each element could be multiple-line, but the indent of its first line is removed
  3087. elements_without_indent.append(f"{ok.lstrip()}: {ov.lstrip()}")
  3088. else:
  3089. for idx, x in enumerate(obj):
  3090. last_element = idx == len(obj) - 1
  3091. o = _format_py_obj(
  3092. x, indent=indent + 1, mode=mode, cache=cache, prefix="," if not last_element else ""
  3093. )
  3094. # Each element could be multiple-line, but the indent of its first line is removed
  3095. elements_without_indent.append(o.lstrip())
  3096. groups = []
  3097. buf = []
  3098. for idx, x in enumerate(elements_without_indent):
  3099. buf.append(x)
  3100. x_expanded = "\n" in buf[-1]
  3101. not_last_element = idx != len(elements_without_indent) - 1
  3102. # if `x` should be separated from subsequent elements
  3103. should_finalize_x = x_expanded or len(f"{' ' * (4 * (indent + 1))}") + len(
  3104. ", ".join(buf[-1:])
  3105. ) > 120 - int(not_last_element)
  3106. # if `buf[:-1]` (i.e. without `x`) should be combined together (into one line)
  3107. should_finalize_buf = x_expanded
  3108. # the recursive call returns single line, so we can use it to determine if we can fit the width limit
  3109. if not should_finalize_buf:
  3110. buf_not_fit_into_one_line = len(f"{' ' * (4 * (indent + 1))}") + len(", ".join(buf)) > 120 - int(
  3111. not_last_element
  3112. )
  3113. should_finalize_buf = buf_not_fit_into_one_line
  3114. # any element of iterable type need to be on its own line
  3115. if (type(obj[idx]) if type(obj) is not dict else type(list(obj.values())[idx])) in [list, tuple, dict]:
  3116. should_finalize_x = True
  3117. should_finalize_buf = True
  3118. # any type change --> need to be added after a new line
  3119. prev_type = None
  3120. current_type = type(obj[idx]) if type(obj) is not dict else type(list(obj.values())[idx])
  3121. if len(buf) > 1:
  3122. prev_type = type(obj[idx - 1]) if type(obj) is not dict else type(list(obj.values())[idx - 1])
  3123. type_changed = current_type != prev_type
  3124. if type_changed:
  3125. should_finalize_buf = True
  3126. # all elements in the buf are string --> don't finalize the buf by width limit
  3127. if prev_type is None or (prev_type is str and current_type is str):
  3128. should_finalize_buf = False
  3129. # collect as many elements of string type as possible (without width limit).
  3130. # These will be examined as a whole (if not fit into the width, each element would be in its own line)
  3131. if current_type is str:
  3132. should_finalize_x = False
  3133. # `len(buf) == 1` or `obj[idx-1]` is a string
  3134. if prev_type in [None, str]:
  3135. should_finalize_buf = False
  3136. if should_finalize_buf:
  3137. orig_buf_len = len(buf)
  3138. if orig_buf_len > 1:
  3139. not_fit_into_one_line = None
  3140. # all elements in `obj` that give `buf[:-1]` are string.
  3141. if prev_type is str:
  3142. # `-1` at the end: because buf[-2] is not the last element
  3143. not_fit_into_one_line = len(f"{' ' * (4 * (indent + 1))}") + len(", ".join(buf[:-1])) > 120 - 1
  3144. if not_fit_into_one_line:
  3145. for x in buf[:-1]:
  3146. groups.append([x])
  3147. else:
  3148. groups.append(buf[:-1])
  3149. buf = buf[-1:]
  3150. if should_finalize_x:
  3151. groups.append(buf)
  3152. buf = []
  3153. # The last buf
  3154. if len(buf) > 0:
  3155. not_fit_into_one_line = None
  3156. if current_type is str:
  3157. # no `-1` at the end: because buf[-1] is the last element
  3158. not_fit_into_one_line = len(f"{' ' * (4 * (indent + 1))}") + len(", ".join(buf)) > 120
  3159. if not_fit_into_one_line:
  3160. for x in buf:
  3161. groups.append([x])
  3162. else:
  3163. groups.append(buf)
  3164. output = f"{' ' * 4 * indent}{p1}\n"
  3165. element_strings = [f"{' ' * (4 * (indent + 1))}" + ", ".join(buf) for buf in groups]
  3166. output += ",\n".join(element_strings)
  3167. output += f"\n{' ' * 4 * indent}{p2}"
  3168. # if all elements are in one-line
  3169. no_new_line_in_elements = all("\n" not in x for x in element_strings)
  3170. # if yes, we can form a one-line representation of `obj`
  3171. could_use_one_line = no_new_line_in_elements
  3172. # if mode == "one-line", this function always returns one-line representation, so `no_new_line_in_elements`
  3173. # will be `True`.
  3174. if could_use_one_line:
  3175. one_line_form = ", ".join([x.lstrip() for x in element_strings])
  3176. one_line_form = f"{p1}{one_line_form}{p2}"
  3177. if mode == "one-line":
  3178. return output
  3179. # check with the width limit
  3180. could_use_one_line = len(f"{' ' * 4 * indent}") + len(prefix) + len(one_line_form) <= 120
  3181. # extra conditions for returning one-line representation
  3182. def use_one_line_repr(obj):
  3183. # iterable types
  3184. if type(obj) in (list, tuple, dict):
  3185. # get all types
  3186. element_types = []
  3187. if type(obj) is dict:
  3188. element_types.extend(type(x) for x in obj.values())
  3189. elif type(obj) in [list, tuple]:
  3190. element_types.extend(type(x) for x in obj)
  3191. # At least one element is of iterable type
  3192. if any(x in (list, tuple, dict) for x in element_types):
  3193. # If `obj` has more than one element and at least one of them is iterable --> no one line repr.
  3194. if len(obj) > 1:
  3195. return False
  3196. # only one element that is iterable, but not the same type as `obj` --> no one line repr.
  3197. if type(obj) is not type(obj[0]):
  3198. return False
  3199. # one-line repr. if possible, without width limit
  3200. return no_new_line_in_elements
  3201. # all elements are of simple types, but more than one type --> no one line repr.
  3202. if len(set(element_types)) > 1:
  3203. return False
  3204. # all elements are of the same simple type
  3205. if element_types[0] in [int, float]:
  3206. # one-line repr. without width limit
  3207. return no_new_line_in_elements
  3208. elif element_types[0] is str:
  3209. if len(obj) == 1:
  3210. # one single string element --> one-line repr. without width limit
  3211. return no_new_line_in_elements
  3212. else:
  3213. # multiple string elements --> one-line repr. if fit into width limit
  3214. return could_use_one_line
  3215. # simple types (int, flat, string)
  3216. return True
  3217. # width condition combined with specific mode conditions
  3218. if use_one_line_repr(obj):
  3219. output = f"{' ' * 4 * indent}{one_line_form}"
  3220. cache[(id(obj), indent, mode, prefix)] = output
  3221. return output
  3222. def write_file(file, content):
  3223. with open(file, "w") as f:
  3224. f.write(content)
  3225. def read_json_file(file):
  3226. with open(file, "r") as fh:
  3227. return json.load(fh)
  3228. # =============================================================================
  3229. # Training CI Utilities - Logging and Memory Monitoring
  3230. # =============================================================================
  3231. # ANSI color codes for terminal output
  3232. class Colors:
  3233. """ANSI color codes for terminal output formatting."""
  3234. RESET = "\033[0m"
  3235. BOLD = "\033[1m"
  3236. DIM = "\033[2m"
  3237. # Foreground colors
  3238. RED = "\033[31m"
  3239. GREEN = "\033[32m"
  3240. YELLOW = "\033[33m"
  3241. BLUE = "\033[34m"
  3242. MAGENTA = "\033[35m"
  3243. CYAN = "\033[36m"
  3244. WHITE = "\033[37m"
  3245. # Bright variants
  3246. BRIGHT_RED = "\033[91m"
  3247. BRIGHT_GREEN = "\033[92m"
  3248. BRIGHT_YELLOW = "\033[93m"
  3249. BRIGHT_BLUE = "\033[94m"
  3250. BRIGHT_CYAN = "\033[96m"
  3251. class ColoredFormatter(logging.Formatter):
  3252. """Custom formatter that adds colors based on log level."""
  3253. LEVEL_COLORS = {
  3254. logging.DEBUG: Colors.DIM + Colors.CYAN,
  3255. logging.INFO: Colors.WHITE,
  3256. logging.WARNING: Colors.BRIGHT_YELLOW,
  3257. logging.ERROR: Colors.BRIGHT_RED,
  3258. logging.CRITICAL: Colors.BOLD + Colors.BRIGHT_RED,
  3259. }
  3260. # Loggers that should be dimmed (less important/verbose)
  3261. DIMMED_LOGGERS = {"httpx", "httpcore", "urllib3", "requests"}
  3262. def __init__(self, fmt: str | None = None, datefmt: str | None = None):
  3263. super().__init__(fmt, datefmt)
  3264. def format(self, record: logging.LogRecord) -> str:
  3265. # Check if this logger should be dimmed
  3266. is_dimmed = record.name in self.DIMMED_LOGGERS
  3267. if is_dimmed:
  3268. # Dim the entire log line for httpx and similar
  3269. timestamp = self.formatTime(record, self.datefmt)
  3270. message = record.getMessage()
  3271. return f"{Colors.DIM}{timestamp} - {record.name} - {record.levelname:8} - {message}{Colors.RESET}"
  3272. # Get color for this level
  3273. color = self.LEVEL_COLORS.get(record.levelno, Colors.RESET)
  3274. # Color the level name
  3275. levelname = record.levelname
  3276. colored_levelname = f"{color}{levelname:8}{Colors.RESET}"
  3277. # Color the timestamp
  3278. colored_time = f"{Colors.DIM}{self.formatTime(record, self.datefmt)}{Colors.RESET}"
  3279. # Color the logger name
  3280. colored_name = f"{Colors.BLUE}{record.name}{Colors.RESET}"
  3281. # Get message
  3282. message = record.getMessage()
  3283. return f"{colored_time} - {colored_name} - {colored_levelname} - {message}"
  3284. _warn_once_logged: set[str] = set()
  3285. def init_test_logger() -> logging.Logger:
  3286. """Initialize a test-specific logger with colored stderr handler and INFO level for tests.
  3287. Uses a named logger instead of root logger to avoid conflicts with pytest-xdist parallel execution.
  3288. Uses stderr instead of stdout to avoid deadlocks with pytest-xdist output capture.
  3289. """
  3290. logger = logging.getLogger("transformers.training_test")
  3291. logger.setLevel(logging.INFO)
  3292. # Only add handler if not already present (avoid duplicate handlers on repeated calls)
  3293. if not logger.handlers:
  3294. # Use stderr instead of stdout - pytest-xdist captures stdout which can cause deadlocks
  3295. ch = logging.StreamHandler(sys.stderr)
  3296. ch.setLevel(logging.INFO)
  3297. # Use colored formatter if terminal supports it, plain otherwise
  3298. if sys.stderr.isatty():
  3299. formatter = ColoredFormatter(datefmt="%Y-%m-%d %H:%M:%S")
  3300. else:
  3301. formatter = logging.Formatter(
  3302. "%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
  3303. )
  3304. ch.setFormatter(formatter)
  3305. logger.addHandler(ch)
  3306. logger.propagate = False # Don't propagate to root logger to avoid duplicate output
  3307. return logger
  3308. def warn_once(logger_instance: logging.Logger, msg: str) -> None:
  3309. """Log a warning message only once per unique message.
  3310. Uses a global set to track messages that have already been logged
  3311. to prevent duplicate warning messages from cluttering the output.
  3312. Args:
  3313. logger_instance: The logger instance to use for warning.
  3314. msg: The warning message to log.
  3315. """
  3316. if msg not in _warn_once_logged:
  3317. logger_instance.warning(msg)
  3318. _warn_once_logged.add(msg)
  3319. # Named tuple for passing memory stats for logging
  3320. MemoryStats = collections.namedtuple(
  3321. "MemoryStats",
  3322. [
  3323. "rss_gib", # Resident Set Size in GiB
  3324. "rss_pct", # RSS as percentage of total memory
  3325. "vms_gib", # Virtual Memory Size in GiB
  3326. "peak_rss_gib", # Peak RSS in GiB
  3327. "peak_rss_pct", # Peak RSS as percentage of total memory
  3328. "available_gib", # Available system memory in GiB
  3329. "total_gib", # Total system memory in GiB
  3330. ],
  3331. )
  3332. class CPUMemoryMonitor:
  3333. """Monitor CPU memory usage for the current process."""
  3334. def __init__(self):
  3335. self.device_name = "CPU"
  3336. self._peak_rss = 0
  3337. self._process = None
  3338. self.total_memory = 0
  3339. self.total_memory_gib = 0
  3340. if is_psutil_available():
  3341. import psutil
  3342. self._process = psutil.Process(os.getpid())
  3343. mem_info = psutil.virtual_memory()
  3344. self.total_memory = mem_info.total
  3345. self.total_memory_gib = self._to_gib(self.total_memory)
  3346. def _to_gib(self, memory_in_bytes: int) -> float:
  3347. """Convert bytes to GiB."""
  3348. return memory_in_bytes / (1024 * 1024 * 1024)
  3349. def _to_pct(self, memory_in_bytes: int) -> float:
  3350. """Convert bytes to percentage of total memory."""
  3351. if self.total_memory == 0:
  3352. return 0.0
  3353. return 100.0 * memory_in_bytes / self.total_memory
  3354. def _update_peak(self) -> None:
  3355. """Update peak memory tracking."""
  3356. if self._process is not None:
  3357. current_rss = self._process.memory_info().rss
  3358. self._peak_rss = max(self._peak_rss, current_rss)
  3359. def get_stats(self) -> MemoryStats:
  3360. """Get current memory statistics."""
  3361. if not is_psutil_available():
  3362. return MemoryStats(0, 0, 0, 0, 0, 0, 0)
  3363. import psutil
  3364. self._update_peak()
  3365. mem_info = self._process.memory_info()
  3366. sys_mem = psutil.virtual_memory()
  3367. return MemoryStats(
  3368. rss_gib=self._to_gib(mem_info.rss),
  3369. rss_pct=self._to_pct(mem_info.rss),
  3370. vms_gib=self._to_gib(mem_info.vms),
  3371. peak_rss_gib=self._to_gib(self._peak_rss),
  3372. peak_rss_pct=self._to_pct(self._peak_rss),
  3373. available_gib=self._to_gib(sys_mem.available),
  3374. total_gib=self._to_gib(sys_mem.total),
  3375. )
  3376. def reset_peak_stats(self) -> None:
  3377. """Reset peak memory tracking."""
  3378. if self._process is not None:
  3379. self._peak_rss = self._process.memory_info().rss
  3380. def build_cpu_memory_monitor(logger_instance: logging.Logger | None = None) -> CPUMemoryMonitor:
  3381. """Build and initialize a CPU memory monitor.
  3382. Args:
  3383. logger_instance: Optional logger to log initialization info. If None, no logging is done.
  3384. Returns:
  3385. CPUMemoryMonitor instance.
  3386. """
  3387. monitor = CPUMemoryMonitor()
  3388. if logger_instance is not None:
  3389. if is_psutil_available():
  3390. logger_instance.info(f"CPU memory monitor initialized: {monitor.total_memory_gib:.2f} GiB total")
  3391. else:
  3392. logger_instance.warning("psutil not available, memory monitoring disabled")
  3393. return monitor
  3394. def convert_all_safetensors_to_bins(folder: str):
  3395. """Convert all safetensors files into torch bin files, to mimic saving with torch (since we still support loading
  3396. bin files, but not saving them anymore)"""
  3397. for file in os.listdir(folder):
  3398. path = os.path.join(folder, file)
  3399. if file.endswith(".safetensors"):
  3400. new_path = path.replace(".safetensors", ".bin").replace("model", "pytorch_model")
  3401. state_dict = load_file(path)
  3402. os.remove(path)
  3403. torch.save(state_dict, new_path)
  3404. # Adapt the index as well
  3405. elif file == SAFE_WEIGHTS_INDEX_NAME:
  3406. new_path = os.path.join(folder, WEIGHTS_INDEX_NAME)
  3407. with open(path) as f:
  3408. index = json.loads(f.read())
  3409. os.remove(path)
  3410. if "weight_map" in index.keys():
  3411. weight_map = index["weight_map"]
  3412. new_weight_map = {}
  3413. for k, v in weight_map.items():
  3414. new_weight_map[k] = v.replace(".safetensors", ".bin").replace("model", "pytorch_model")
  3415. index["weight_map"] = new_weight_map
  3416. with open(new_path, "w") as f:
  3417. f.write(json.dumps(index, indent=4))
  3418. @contextmanager
  3419. def force_serialization_as_bin_files():
  3420. """Since we don't support saving with torch `.bin` files anymore, but still support loading them, we use this context
  3421. to easily create the bin files and try to load them back"""
  3422. try:
  3423. # Monkey patch the method to save as bin files
  3424. original_save = PreTrainedModel.save_pretrained
  3425. def new_save(self, save_directory, *args, **kwargs):
  3426. original_save(self, save_directory, *args, **kwargs)
  3427. convert_all_safetensors_to_bins(save_directory)
  3428. PreTrainedModel.save_pretrained = new_save
  3429. yield
  3430. finally:
  3431. PreTrainedModel.save_pretrained = original_save