import_utils.py 111 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Import utilities: Utilities related to imports and our lazy inits.
  16. """
  17. import functools
  18. import importlib.machinery
  19. import importlib.metadata
  20. import importlib.util
  21. import json
  22. import operator
  23. import os
  24. import re
  25. import shutil
  26. import subprocess
  27. import sys
  28. import warnings
  29. from collections import OrderedDict
  30. from collections.abc import Callable
  31. from enum import Enum
  32. from functools import lru_cache
  33. from itertools import chain
  34. from types import ModuleType
  35. from typing import Any
  36. import packaging.version
  37. from packaging import version
  38. from . import logging
# Module-level logger used by the availability checks in this file.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Mapping from top-level import names to the distribution names that provide them (e.g. the `PIL` import is
# provided by the `pillow` distribution). Computed once, at module import time, and consulted by
# `_is_package_available` to resolve installed versions.
PACKAGE_DISTRIBUTION_MAPPING = importlib.metadata.packages_distributions()
  41. def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str]:
  42. """Check if `pkg_name` exist, and optionally try to get its version"""
  43. spec = importlib.util.find_spec(pkg_name)
  44. package_exists = spec is not None
  45. package_version = "N/A"
  46. if package_exists and return_version:
  47. try:
  48. # importlib.metadata works with the distribution package, which may be different from the import
  49. # name (e.g. `PIL` is the import name, but `pillow` is the distribution name)
  50. distributions = PACKAGE_DISTRIBUTION_MAPPING[pkg_name]
  51. # Per PEP 503, underscores and hyphens are equivalent in package names.
  52. # Prefer the distribution that matches the (normalized) package name.
  53. normalized_pkg_name = pkg_name.replace("_", "-")
  54. if normalized_pkg_name in distributions:
  55. distribution_name = normalized_pkg_name
  56. elif pkg_name in distributions:
  57. distribution_name = pkg_name
  58. else:
  59. distribution_name = distributions[0]
  60. package_version = importlib.metadata.version(distribution_name)
  61. except (importlib.metadata.PackageNotFoundError, KeyError):
  62. # If we cannot find the metadata (because of editable install for example), try to import directly.
  63. # Note that this branch will almost never be run, so we do not import packages for nothing here
  64. package = importlib.import_module(pkg_name)
  65. package_version = getattr(package, "__version__", "N/A")
  66. logger.debug(f"Detected {pkg_name} version: {package_version}")
  67. if return_version:
  68. return package_exists, package_version
  69. else:
  70. return package_exists, None
  71. def resolve_internal_import(module: ModuleType | None, chained_path: str) -> Callable | ModuleType | None:
  72. """
  73. Check if a given `module` has an internal import path as defined by the `chained_path`.
  74. This can either be the full path (not exposed in `__init__`) OR the last part of the chain (exposed in `__init__`).
  75. This is an important helper function for kernels based modules to apply the import from the module
  76. itself, i.e. stay compatible with original libraries in certain cases.
  77. Example:
  78. Module: `mamba_ssm`
  79. Chained Path: `ops.triton.selective_state_update.selective_state_update`
  80. Resulting import attempt at:
  81. - `mamba_ssm.selective_state_update`
  82. - `mamba_ssm.ops.triton.selective_state_update.selective_state_update`
  83. """
  84. if not module:
  85. return None
  86. if final_module := getattr(module, chained_path.split(".")[-1], None):
  87. return final_module
  88. final_module = module
  89. for path in chained_path.split("."):
  90. final_module = getattr(final_module, path, None)
  91. if not final_module:
  92. return None
  93. return final_module
  94. def is_env_variable_true(env_variable: str) -> bool:
  95. """Detect whether `env_variable` has been set to a true value in the environment"""
  96. return os.getenv(env_variable, "false").lower() in ("true", "1", "y", "yes", "on")
  97. def is_env_variable_false(env_variable: str) -> bool:
  98. """Detect whether `env_variable` has been set to a false value in the environment"""
  99. return os.getenv(env_variable, "true").lower() in ("false", "0", "n", "no", "off")
  100. ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
  101. ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
  102. # Try to run a native pytorch job in an environment with TorchXLA installed by setting this value to 0.
  103. USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()
  104. ACCELERATE_MIN_VERSION = "1.1.0"
  105. BITSANDBYTES_MIN_VERSION = "0.46.1"
  106. SCHEDULEFREE_MIN_VERSION = "1.2.6"
  107. FSDP_MIN_VERSION = "1.12.0"
  108. GGUF_MIN_VERSION = "0.10.0"
  109. XLA_FSDPV2_MIN_VERSION = "2.2.0"
  110. HQQ_MIN_VERSION = "0.2.1"
  111. VPTQ_MIN_VERSION = "0.0.4"
  112. TORCHAO_MIN_VERSION = "0.15.0"
  113. AUTOROUND_MIN_VERSION = "0.5.0"
  114. TRITON_MIN_VERSION = "1.0.0"
  115. KERNELS_MIN_VERSION = "0.10.2"
  116. @lru_cache
  117. def is_torch_available() -> bool:
  118. try:
  119. is_available, torch_version = _is_package_available("torch", return_version=True)
  120. parsed_version = version.parse(torch_version)
  121. if is_available and parsed_version < version.parse("2.4.0"):
  122. logger.warning_once(f"Disabling PyTorch because PyTorch >= 2.4 is required but found {torch_version}")
  123. return is_available and version.parse(torch_version) >= version.parse("2.4.0")
  124. except packaging.version.InvalidVersion:
  125. return False
  126. @lru_cache
  127. def get_torch_version() -> str:
  128. _, torch_version = _is_package_available("torch", return_version=True)
  129. return torch_version
  130. @lru_cache
  131. def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False) -> bool:
  132. """
  133. Accepts a library version and returns True if the current version of the library is greater than or equal to the
  134. given version. If `accept_dev` is True, it will also accept development versions (e.g. 2.7.0.dev20250320 matches
  135. 2.7.0).
  136. """
  137. if not is_torch_available():
  138. return False
  139. if accept_dev:
  140. return version.parse(version.parse(get_torch_version()).base_version) >= version.parse(library_version)
  141. else:
  142. return version.parse(get_torch_version()) >= version.parse(library_version)
  143. @lru_cache
  144. def is_torch_less_or_equal(library_version: str, accept_dev: bool = False) -> bool:
  145. """
  146. Accepts a library version and returns True if the current version of the library is less than or equal to the
  147. given version. If `accept_dev` is True, it will also accept development versions (e.g. 2.7.0.dev20250320 matches
  148. 2.7.0).
  149. """
  150. if not is_torch_available():
  151. return False
  152. if accept_dev:
  153. return version.parse(version.parse(get_torch_version()).base_version) <= version.parse(library_version)
  154. else:
  155. return version.parse(get_torch_version()) <= version.parse(library_version)
  156. @lru_cache
  157. def is_torch_accelerator_available() -> bool:
  158. if is_torch_available():
  159. import torch
  160. return hasattr(torch, "accelerator")
  161. return False
  162. @lru_cache
  163. def is_torch_cuda_available() -> bool:
  164. if is_torch_available():
  165. import torch
  166. return torch.cuda.is_available()
  167. return False
  168. @lru_cache
  169. def is_cuda_platform() -> bool:
  170. if is_torch_available():
  171. import torch
  172. return getattr(torch, "version").cuda is not None
  173. return False
  174. @lru_cache
  175. def get_cuda_runtime_version() -> tuple[int, int]:
  176. """Return the CUDA runtime version as (major, minor).
  177. Unlike ``torch.version.cuda`` which reports the compile-time version,
  178. this queries ``cudaRuntimeGetVersion`` from ``libcudart.so`` to get the
  179. actual runtime version installed on the system.
  180. """
  181. import ctypes
  182. version = ctypes.c_int()
  183. cudart = ctypes.CDLL("libcudart.so")
  184. cudart.cudaRuntimeGetVersion(ctypes.byref(version))
  185. return version.value // 1000, (version.value % 1000) // 10
  186. @lru_cache
  187. def is_rocm_platform() -> bool:
  188. if is_torch_available():
  189. import torch
  190. return getattr(torch, "version").hip is not None
  191. return False
  192. @lru_cache
  193. def is_habana_gaudi1() -> bool:
  194. if not is_torch_hpu_available():
  195. return False
  196. import habana_frameworks.torch.utils.experimental as htexp
  197. # Check if the device is Gaudi1 (vs Gaudi2, Gaudi3)
  198. return htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi
  199. @lru_cache
  200. def is_torch_mps_available(min_version: str | None = None) -> bool:
  201. if is_torch_available():
  202. import torch
  203. backend_available = torch.backends.mps.is_available() and torch.backends.mps.is_built()
  204. if min_version is not None:
  205. flag = version.parse(get_torch_version()) >= version.parse(min_version)
  206. backend_available = backend_available and flag
  207. return backend_available
  208. return False
  209. @lru_cache
  210. def is_torch_npu_available(check_device=False) -> bool:
  211. "Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
  212. if not is_torch_available() or not _is_package_available("torch_npu")[0]:
  213. return False
  214. import torch
  215. import torch_npu # noqa: F401
  216. if check_device:
  217. try:
  218. # Will raise a RuntimeError if no NPU is found
  219. if hasattr(torch, "npu"):
  220. _ = torch.npu.device_count()
  221. return torch.npu.is_available()
  222. return False
  223. except RuntimeError:
  224. return False
  225. return hasattr(torch, "npu") and torch.npu.is_available()
  226. @lru_cache
  227. def is_torch_xpu_available(check_device: bool = False) -> bool:
  228. """
  229. Checks if XPU acceleration is available via stock PyTorch (>=2.6) and
  230. potentially if a XPU is in the environment.
  231. """
  232. if not is_torch_available():
  233. return False
  234. torch_version = version.parse(get_torch_version())
  235. if torch_version.major == 2 and torch_version.minor < 6:
  236. return False
  237. import torch
  238. if check_device:
  239. try:
  240. # Will raise a RuntimeError if no XPU is found
  241. _ = torch.xpu.device_count()
  242. return torch.xpu.is_available()
  243. except RuntimeError:
  244. return False
  245. return hasattr(torch, "xpu") and torch.xpu.is_available()
  246. @lru_cache
  247. def is_torch_mlu_available() -> bool:
  248. """
  249. Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu
  250. uninitialized.
  251. """
  252. if not is_torch_available() or not _is_package_available("torch_mlu")[0]:
  253. return False
  254. import torch
  255. import torch_mlu # noqa: F401
  256. pytorch_cndev_based_mlu_check_previous_value = os.environ.get("PYTORCH_CNDEV_BASED_MLU_CHECK")
  257. try:
  258. os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = str(1)
  259. available = torch.mlu.is_available() if hasattr(torch, "mlu") else False
  260. finally:
  261. if pytorch_cndev_based_mlu_check_previous_value:
  262. os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = pytorch_cndev_based_mlu_check_previous_value
  263. else:
  264. os.environ.pop("PYTORCH_CNDEV_BASED_MLU_CHECK", None)
  265. return available
  266. @lru_cache
  267. def is_torch_musa_available(check_device=False) -> bool:
  268. "Checks if `torch_musa` is installed and potentially if a MUSA is in the environment"
  269. if not is_torch_available() or not _is_package_available("torch_musa")[0]:
  270. return False
  271. import torch
  272. import torch_musa # noqa: F401
  273. torch_musa_min_version = "0.33.0"
  274. accelerate_available, accelerate_version = _is_package_available("accelerate", return_version=True)
  275. if accelerate_available and version.parse(accelerate_version) < version.parse(torch_musa_min_version):
  276. return False
  277. if check_device:
  278. try:
  279. # Will raise a RuntimeError if no MUSA is found
  280. if hasattr(torch, "musa"):
  281. _ = torch.musa.device_count()
  282. return torch.musa.is_available()
  283. return False
  284. except RuntimeError:
  285. return False
  286. return hasattr(torch, "musa") and torch.musa.is_available()
  287. @lru_cache
  288. def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False) -> bool:
  289. """
  290. Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
  291. the USE_TORCH_XLA to false.
  292. """
  293. assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."
  294. torch_xla_available = USE_TORCH_XLA in ENV_VARS_TRUE_VALUES and _is_package_available("torch_xla")[0]
  295. if not torch_xla_available:
  296. return False
  297. import torch_xla
  298. if check_is_gpu:
  299. return torch_xla.runtime.device_type() in ["GPU", "CUDA"]
  300. elif check_is_tpu:
  301. return torch_xla.runtime.device_type() == "TPU"
  302. return True
  303. @lru_cache
  304. def is_torch_hpu_available() -> bool:
  305. "Checks if `torch.hpu` is available and potentially if a HPU is in the environment"
  306. if (
  307. not is_torch_available()
  308. or not _is_package_available("habana_frameworks")[0]
  309. or not _is_package_available("habana_frameworks.torch")[0]
  310. ):
  311. return False
  312. torch_hpu_min_accelerate_version = "1.5.0"
  313. accelerate_available, accelerate_version = _is_package_available("accelerate", return_version=True)
  314. if accelerate_available and version.parse(accelerate_version) < version.parse(torch_hpu_min_accelerate_version):
  315. return False
  316. import torch
  317. if os.environ.get("PT_HPU_LAZY_MODE", "1") == "1":
  318. # import habana_frameworks.torch in case of lazy mode to patch torch with torch.hpu
  319. import habana_frameworks.torch # noqa: F401
  320. if not hasattr(torch, "hpu") or not torch.hpu.is_available():
  321. return False
  322. # We patch torch.gather for int64 tensors to avoid a bug on Gaudi
  323. # Graph compile failed with synStatus 26 [Generic failure]
  324. # This can be removed once bug is fixed but for now we need it.
  325. original_gather = torch.gather
  326. def patched_gather(input: torch.Tensor, dim: int, index: torch.LongTensor) -> torch.Tensor:
  327. if input.dtype == torch.int64 and input.device.type == "hpu":
  328. return original_gather(input.to(torch.int32), dim, index).to(torch.int64)
  329. else:
  330. return original_gather(input, dim, index)
  331. torch.gather = patched_gather
  332. torch.Tensor.gather = patched_gather
  333. original_take_along_dim = torch.take_along_dim
  334. def patched_take_along_dim(input: torch.Tensor, indices: torch.LongTensor, dim: int | None = None) -> torch.Tensor:
  335. if input.dtype == torch.int64 and input.device.type == "hpu":
  336. return original_take_along_dim(input.to(torch.int32), indices, dim).to(torch.int64)
  337. else:
  338. return original_take_along_dim(input, indices, dim)
  339. torch.take_along_dim = patched_take_along_dim
  340. original_cholesky = torch.linalg.cholesky
  341. def safe_cholesky(A, *args, **kwargs):
  342. output = original_cholesky(A, *args, **kwargs)
  343. if torch.isnan(output).any():
  344. jitter_value = 1e-9
  345. diag_jitter = torch.eye(A.size(-1), dtype=A.dtype, device=A.device) * jitter_value
  346. output = original_cholesky(A + diag_jitter, *args, **kwargs)
  347. return output
  348. torch.linalg.cholesky = safe_cholesky
  349. original_scatter = torch.scatter
  350. def patched_scatter(
  351. input: torch.Tensor, dim: int, index: torch.Tensor, src: torch.Tensor, *args, **kwargs
  352. ) -> torch.Tensor:
  353. if input.device.type == "hpu" and input is src:
  354. return original_scatter(input, dim, index, src.clone(), *args, **kwargs)
  355. else:
  356. return original_scatter(input, dim, index, src, *args, **kwargs)
  357. torch.scatter = patched_scatter
  358. torch.Tensor.scatter = patched_scatter
  359. # IlyasMoutawwakil: we patch torch.compile to use the HPU backend by default
  360. # https://github.com/huggingface/transformers/pull/38790#discussion_r2157043944
  361. # This is necessary for cases where torch.compile is used as a decorator (defaulting to inductor)
  362. # https://github.com/huggingface/transformers/blob/af6120b3eb2470b994c21421bb6eaa76576128b0/src/transformers/models/modernbert/modeling_modernbert.py#L204
  363. original_compile = torch.compile
  364. def hpu_backend_compile(*args, **kwargs):
  365. if kwargs.get("backend") not in ["hpu_backend", "eager"]:
  366. logger.warning(
  367. f"Calling torch.compile with backend={kwargs.get('backend')} on a Gaudi device is not supported. "
  368. "We will override the backend with 'hpu_backend' to avoid errors."
  369. )
  370. kwargs["backend"] = "hpu_backend"
  371. return original_compile(*args, **kwargs)
  372. torch.compile = hpu_backend_compile
  373. return True
  374. @lru_cache
  375. def is_torch_neuron_available(check_device: bool = False) -> bool:
  376. import torch
  377. if importlib.util.find_spec("torch_neuronx") is None:
  378. return False
  379. if check_device:
  380. try:
  381. import torch_neuronx # noqa: F401
  382. # Will raise a RuntimeError if no Neuron is found
  383. if hasattr(torch, "neuron"):
  384. _ = torch.neuron.device_count()
  385. return torch.neuron.is_available()
  386. return False
  387. except RuntimeError:
  388. return False
  389. return hasattr(torch, "neuron") and torch.neuron.is_available()
  390. @lru_cache
  391. def is_torch_bf16_gpu_available() -> bool:
  392. if not is_torch_available():
  393. return False
  394. import torch
  395. if torch.cuda.is_available():
  396. return torch.cuda.is_bf16_supported()
  397. if is_torch_xpu_available():
  398. return torch.xpu.is_bf16_supported()
  399. if is_torch_hpu_available():
  400. return True
  401. if is_torch_npu_available() and hasattr(torch, "npu"):
  402. return torch.npu.is_bf16_supported()
  403. if is_torch_mps_available():
  404. # Note: Emulated in software by Metal using fp32 for hardware without native support (like M1/M2)
  405. return torch.backends.mps.is_macos_or_newer(14, 0)
  406. if is_torch_musa_available() and hasattr(torch, "musa"):
  407. return torch.musa.is_bf16_supported()
  408. if is_torch_mlu_available() and hasattr(torch, "mlu"):
  409. return torch.mlu.is_bf16_supported()
  410. if is_torch_neuron_available() and hasattr(torch, "neuron"):
  411. return torch.neuron.is_bf16_supported()
  412. return False
  413. @lru_cache
  414. def is_torch_fp16_available_on_device(device: str) -> bool:
  415. if not is_torch_available():
  416. return False
  417. if is_torch_hpu_available():
  418. if is_habana_gaudi1():
  419. return False
  420. else:
  421. return True
  422. import torch
  423. try:
  424. x = torch.zeros(2, 2, dtype=torch.float16, device=device)
  425. _ = x @ x
  426. # At this moment, let's be strict of the check: check if `LayerNorm` is also supported on device, because many
  427. # models use this layer.
  428. batch, sentence_length, embedding_dim = 3, 4, 5
  429. embedding = torch.randn(batch, sentence_length, embedding_dim, dtype=torch.float16, device=device)
  430. layer_norm = torch.nn.LayerNorm(embedding_dim, dtype=torch.float16, device=device)
  431. _ = layer_norm(embedding)
  432. return True
  433. except Exception:
  434. return False
  435. @lru_cache
  436. def is_torch_bf16_available_on_device(device: str) -> bool:
  437. if not is_torch_available():
  438. return False
  439. import torch
  440. if device == "cuda":
  441. return is_torch_bf16_gpu_available()
  442. if device == "hpu":
  443. return True
  444. try:
  445. x = torch.zeros(2, 2, dtype=torch.bfloat16, device=device)
  446. _ = x @ x
  447. return True
  448. except Exception:
  449. return False
  450. @lru_cache
  451. def is_torch_tf32_available() -> bool:
  452. if not is_torch_available():
  453. return False
  454. import torch
  455. if is_torch_musa_available() and hasattr(torch, "musa"):
  456. device_info = torch.musa.get_device_properties(torch.musa.current_device())
  457. if f"{device_info.major}{device_info.minor}" >= "22":
  458. return True
  459. return False
  460. torch_version = getattr(torch, "version")
  461. if not torch.cuda.is_available() or torch_version.cuda is None:
  462. return False
  463. if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
  464. return False
  465. return True
  466. @lru_cache
  467. def enable_tf32(enable: bool) -> None:
  468. """
  469. Set TF32 mode using the appropriate PyTorch API.
  470. For PyTorch 2.9+, uses the new fp32_precision API.
  471. For older versions, uses the legacy allow_tf32 flags.
  472. Args:
  473. enable: Whether to enable TF32 mode
  474. """
  475. import torch
  476. pytorch_version = version.parse(get_torch_version())
  477. if pytorch_version >= version.parse("2.9.0"):
  478. precision_mode = "tf32" if enable else "ieee"
  479. if hasattr(torch.backends, "fp32_precision"):
  480. torch.backends.fp32_precision = precision_mode
  481. else:
  482. if is_torch_musa_available():
  483. if hasattr(torch.backends, "mudnn"):
  484. torch.backends.mudnn.allow_tf32 = enable
  485. else:
  486. torch.backends.cuda.matmul.allow_tf32 = enable
  487. torch.backends.cudnn.allow_tf32 = enable
  488. @lru_cache
  489. def is_torch_flex_attn_available() -> bool:
  490. return is_torch_available() and version.parse(get_torch_version()) >= version.parse("2.5.0")
  491. @lru_cache
  492. def is_grouped_mm_available() -> bool:
  493. return is_torch_available() and version.parse(get_torch_version()) >= version.parse("2.9.0")
  494. @lru_cache
  495. def is_kenlm_available() -> bool:
  496. return _is_package_available("kenlm")[0]
  497. @lru_cache
  498. def is_kernels_available(MIN_VERSION: str = KERNELS_MIN_VERSION) -> bool:
  499. is_available, kernels_version = _is_package_available("kernels", return_version=True)
  500. return is_available and version.parse(kernels_version) >= version.parse(MIN_VERSION)
  501. @lru_cache
  502. def is_cv2_available() -> bool:
  503. return _is_package_available("cv2")[0]
  504. @lru_cache
  505. def is_yt_dlp_available() -> bool:
  506. return _is_package_available("yt_dlp")[0]
  507. @lru_cache
  508. def is_libcst_available() -> bool:
  509. return _is_package_available("libcst")[0]
  510. @lru_cache
  511. def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION) -> bool:
  512. if not is_torch_available():
  513. return False
  514. is_available, accelerate_version = _is_package_available("accelerate", return_version=True)
  515. return is_available and version.parse(accelerate_version) >= version.parse(min_version)
  516. @lru_cache
  517. def is_triton_available(min_version: str = TRITON_MIN_VERSION) -> bool:
  518. is_available, triton_version = _is_package_available("triton", return_version=True)
  519. return is_available and version.parse(triton_version) >= version.parse(min_version)
  520. @lru_cache
  521. def is_hadamard_available() -> bool:
  522. return _is_package_available("fast_hadamard_transform")[0]
  523. @lru_cache
  524. def is_hqq_available(min_version: str = HQQ_MIN_VERSION) -> bool:
  525. is_available, hqq_version = _is_package_available("hqq", return_version=True)
  526. return is_available and version.parse(hqq_version) >= version.parse(min_version)
  527. @lru_cache
  528. def is_pygments_available() -> bool:
  529. return _is_package_available("pygments")[0]
  530. @lru_cache
  531. def is_torchvision_available() -> bool:
  532. return is_vision_available() and is_torch_available() and _is_package_available("torchvision")[0]
  533. @lru_cache
  534. def is_torchvision_v2_available() -> bool:
  535. return is_torchvision_available()
  536. @lru_cache
  537. def is_galore_torch_available() -> bool:
  538. return _is_package_available("galore_torch")[0]
  539. @lru_cache
  540. def is_apollo_torch_available() -> bool:
  541. return _is_package_available("apollo_torch")[0]
  542. @lru_cache
  543. def is_torch_optimi_available() -> bool:
  544. return _is_package_available("optimi")[0]
  545. @lru_cache
  546. def is_lomo_available() -> bool:
  547. return _is_package_available("lomo_optim")[0]
  548. @lru_cache
  549. def is_grokadamw_available() -> bool:
  550. return _is_package_available("grokadamw")[0]
  551. @lru_cache
  552. def is_schedulefree_available(min_version: str = SCHEDULEFREE_MIN_VERSION) -> bool:
  553. is_available, schedulefree_version = _is_package_available("schedulefree", return_version=True)
  554. return is_available and version.parse(schedulefree_version) >= version.parse(min_version)
  555. @lru_cache
  556. def is_pyctcdecode_available() -> bool:
  557. return _is_package_available("pyctcdecode")[0]
  558. @lru_cache
  559. def is_librosa_available() -> bool:
  560. return _is_package_available("librosa")[0]
  561. @lru_cache
  562. def is_multipart_available() -> bool:
  563. return _is_package_available("multipart")[0]
  564. @lru_cache
  565. def is_essentia_available() -> bool:
  566. return _is_package_available("essentia")[0]
  567. @lru_cache
  568. def is_pydantic_available() -> bool:
  569. return _is_package_available("pydantic")[0]
  570. @lru_cache
  571. def is_fastapi_available() -> bool:
  572. return _is_package_available("fastapi")[0]
  573. @lru_cache
  574. def is_uvicorn_available() -> bool:
  575. return _is_package_available("uvicorn")[0]
  576. @lru_cache
  577. def is_openai_available() -> bool:
  578. return _is_package_available("openai")[0]
  579. @lru_cache
  580. def is_serve_available() -> bool:
  581. return is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available()
  582. @lru_cache
  583. def is_pretty_midi_available() -> bool:
  584. return _is_package_available("pretty_midi")[0]
  585. @lru_cache
  586. def is_mamba_ssm_available() -> bool:
  587. return is_torch_cuda_available() and _is_package_available("mamba_ssm")[0]
  588. @lru_cache
  589. def is_mamba_2_ssm_available() -> bool:
  590. is_available, mamba_ssm_version = _is_package_available("mamba_ssm", return_version=True)
  591. return is_torch_cuda_available() and is_available and version.parse(mamba_ssm_version) >= version.parse("2.0.4")
  592. @lru_cache
  593. def is_flash_linear_attention_available():
  594. is_available, fla_version = _is_package_available("fla", return_version=True)
  595. return is_torch_cuda_available() and is_available and version.parse(fla_version) >= version.parse("0.2.2")
  596. @lru_cache
  597. def is_causal_conv1d_available() -> bool:
  598. return is_torch_cuda_available() and _is_package_available("causal_conv1d")[0]
  599. @lru_cache
  600. def is_xlstm_available() -> bool:
  601. return is_torch_available() and _is_package_available("xlstm")[0]
  602. @lru_cache
  603. def is_mambapy_available() -> bool:
  604. return is_torch_available() and _is_package_available("mambapy")[0]
  605. @lru_cache
  606. def is_peft_available() -> bool:
  607. return _is_package_available("peft")[0]
  608. @lru_cache
  609. def is_bs4_available() -> bool:
  610. return _is_package_available("bs4")[0]
  611. @lru_cache
  612. def is_coloredlogs_available() -> bool:
  613. return _is_package_available("coloredlogs")[0]
  614. @lru_cache
  615. def is_onnx_available() -> bool:
  616. return _is_package_available("onnx")[0]
  617. @lru_cache
  618. def is_flute_available() -> bool:
  619. is_available, flute_version = _is_package_available("flute", return_version=True)
  620. return is_available and version.parse(flute_version) >= version.parse("0.4.1")
  621. @lru_cache
  622. def is_g2p_en_available() -> bool:
  623. return _is_package_available("g2p_en")[0]
  624. @lru_cache
  625. def is_torch_neuroncore_available(check_device=True) -> bool:
  626. return is_torch_xla_available() and _is_package_available("torch_neuronx")[0]
  627. @lru_cache
  628. def is_torch_tensorrt_fx_available() -> bool:
  629. return _is_package_available("torch_tensorrt")[0] and _is_package_available("torch_tensorrt.fx")[0]
  630. @lru_cache
  631. def is_datasets_available() -> bool:
  632. return _is_package_available("datasets")[0]
  633. @lru_cache
  634. def is_detectron2_available() -> bool:
  635. # We need this try/except block because otherwise after uninstalling the library, it stays available for some reason
  636. # i.e. `import detectron2` and `import detectron2.modeling` still work, even though the library is uninstalled
  637. # (the package exists but the objects are not reachable) - so here we explicitly try to import an object from it
  638. try:
  639. from detectron2.modeling import META_ARCH_REGISTRY # noqa
  640. return True
  641. except Exception:
  642. return False
  643. @lru_cache
  644. def is_rjieba_available() -> bool:
  645. return _is_package_available("rjieba")[0]
  646. @lru_cache
  647. def is_psutil_available() -> bool:
  648. return _is_package_available("psutil")[0]
  649. @lru_cache
  650. def is_py3nvml_available() -> bool:
  651. return _is_package_available("py3nvml")[0]
  652. @lru_cache
  653. def is_sacremoses_available() -> bool:
  654. return _is_package_available("sacremoses")[0]
  655. @lru_cache
  656. def is_apex_available() -> bool:
  657. return _is_package_available("apex")[0]
  658. @lru_cache
  659. def is_aqlm_available() -> bool:
  660. return _is_package_available("aqlm")[0]
  661. @lru_cache
  662. def is_vptq_available(min_version: str = VPTQ_MIN_VERSION) -> bool:
  663. is_available, vptq_version = _is_package_available("vptq", return_version=True)
  664. return is_available and version.parse(vptq_version) >= version.parse(min_version)
  665. @lru_cache
  666. def is_av_available() -> bool:
  667. return _is_package_available("av")[0]
  668. @lru_cache
  669. def is_decord_available() -> bool:
  670. return _is_package_available("decord")[0]
  671. @lru_cache
  672. def is_torchcodec_available() -> bool:
  673. return _is_package_available("torchcodec")[0]
  674. @lru_cache
  675. def is_ninja_available() -> bool:
  676. r"""
  677. Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
  678. [ninja](https://ninja-build.org/) build system is available on the system, `False` otherwise.
  679. """
  680. try:
  681. subprocess.check_output(["ninja", "--version"])
  682. except Exception:
  683. return False
  684. else:
  685. return True
  686. @lru_cache
  687. def is_bitsandbytes_available(min_version: str = BITSANDBYTES_MIN_VERSION) -> bool:
  688. is_available, bitsandbytes_version = _is_package_available("bitsandbytes", return_version=True)
  689. return is_available and version.parse(bitsandbytes_version) >= version.parse(min_version)
  690. @lru_cache
  691. def is_flash_attn_2_available() -> bool:
  692. is_available, flash_attn_version = _is_package_available("flash_attn", return_version=True)
  693. # FA4 is also distributed under "flash_attn", hence we need to check the naming here
  694. is_available = is_available and "flash-attn" in [
  695. pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"]
  696. ]
  697. if not is_available or not (is_torch_cuda_available() or is_torch_mlu_available()):
  698. return False
  699. # Only allow versions >= 2.3.3 to avoid very old legacy workarounds that are now 2+ years old
  700. try:
  701. return version.parse(flash_attn_version) >= version.parse("2.3.3")
  702. except packaging.version.InvalidVersion:
  703. return False
  704. @lru_cache
  705. def is_flash_attn_3_available() -> bool:
  706. # Universally available under `flash_attn_interface`
  707. is_available = _is_package_available("flash_attn_interface")[0]
  708. # Resolving and ensuring the proper name of FA3 being associated
  709. is_available = is_available and "flash-attn-3" in [
  710. pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn_interface"]
  711. ]
  712. return is_available and is_torch_cuda_available()
  713. @lru_cache
  714. def is_flash_attn_4_available() -> bool:
  715. is_available = _is_package_available("flash_attn")[0]
  716. # FA2 is also distributed under "flash_attn", hence we need to check the naming here
  717. # NOTE: FA2 seems to distribute the `cute` subdirectory even if only FA2 has been installed
  718. # -> check for the proper (normalized) distribution name
  719. is_available = is_available and "flash-attn-4" in [
  720. pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"]
  721. ]
  722. return is_available and is_torch_cuda_available()
  723. @lru_cache
  724. def is_flash_attn_greater_or_equal(library_version: str) -> bool:
  725. is_available, flash_attn_version = _is_package_available("flash_attn", return_version=True)
  726. # FA4 is also distributed under "flash_attn", hence we need to check the naming here
  727. is_available = is_available and "flash-attn" in [
  728. pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"]
  729. ]
  730. if not is_available:
  731. return False
  732. try:
  733. return version.parse(flash_attn_version) >= version.parse(library_version)
  734. except packaging.version.InvalidVersion:
  735. return False
  736. @lru_cache
  737. def is_flash_attn_greater_or_equal_2_10() -> bool:
  738. warnings.warn(
  739. "`is_flash_attn_greater_or_equal_2_10` is deprecated and will be removed in v5.8. "
  740. "Please use `is_flash_attn_greater_or_equal(library_version='2.1.0')` instead if needed.",
  741. FutureWarning,
  742. )
  743. return is_flash_attn_greater_or_equal("2.1.0")
  744. @lru_cache
  745. def is_huggingface_hub_greater_or_equal(library_version: str, accept_dev: bool = False) -> bool:
  746. is_available, hub_version = _is_package_available("huggingface_hub", return_version=True)
  747. if not is_available:
  748. return False
  749. if accept_dev:
  750. return version.parse(version.parse(hub_version).base_version) >= version.parse(library_version)
  751. else:
  752. return version.parse(hub_version) >= version.parse(library_version)
  753. @lru_cache
  754. def is_quanto_greater(library_version: str, accept_dev: bool = False) -> bool:
  755. """
  756. Accepts a library version and returns True if the current version of the library is greater than or equal to the
  757. given version. If `accept_dev` is True, it will also accept development versions (e.g. 2.7.0.dev20250320 matches
  758. 2.7.0).
  759. """
  760. if not is_optimum_quanto_available():
  761. return False
  762. _, quanto_version = _is_package_available("optimum.quanto", return_version=True)
  763. if accept_dev:
  764. return version.parse(version.parse(quanto_version).base_version) > version.parse(library_version)
  765. else:
  766. return version.parse(quanto_version) > version.parse(library_version)
  767. @lru_cache
  768. def is_torchdistx_available():
  769. return _is_package_available("torchdistx")[0]
  770. @lru_cache
  771. def is_faiss_available() -> bool:
  772. return _is_package_available("faiss")[0]
  773. @lru_cache
  774. def is_fouroversix_available() -> bool:
  775. return _is_package_available("fouroversix")
  776. @lru_cache
  777. def is_scipy_available() -> bool:
  778. return _is_package_available("scipy")[0]
  779. @lru_cache
  780. def is_sklearn_available() -> bool:
  781. return _is_package_available("sklearn")[0]
  782. @lru_cache
  783. def is_sentencepiece_available() -> bool:
  784. return _is_package_available("sentencepiece")[0]
  785. @lru_cache
  786. def is_seqio_available() -> bool:
  787. return _is_package_available("seqio")[0]
  788. @lru_cache
  789. def is_gguf_available(min_version: str = GGUF_MIN_VERSION) -> bool:
  790. is_available, gguf_version = _is_package_available("gguf", return_version=True)
  791. return is_available and version.parse(gguf_version) >= version.parse(min_version)
  792. @lru_cache
  793. def is_protobuf_available() -> bool:
  794. return _is_package_available("google")[0] and _is_package_available("google.protobuf")[0]
  795. @lru_cache
  796. def is_fsdp_available(min_version: str = FSDP_MIN_VERSION) -> bool:
  797. return is_torch_available() and version.parse(get_torch_version()) >= version.parse(min_version)
  798. @lru_cache
  799. def is_optimum_available() -> bool:
  800. return _is_package_available("optimum")[0]
  801. @lru_cache
  802. def is_llm_awq_available() -> bool:
  803. return _is_package_available("awq")[0]
  804. @lru_cache
  805. def is_auto_round_available(min_version: str = AUTOROUND_MIN_VERSION) -> bool:
  806. is_available, auto_round_version = _is_package_available("auto_round", return_version=True)
  807. return is_available and version.parse(auto_round_version) >= version.parse(min_version)
  808. @lru_cache
  809. def is_optimum_quanto_available():
  810. return is_optimum_available() and _is_package_available("optimum.quanto")[0]
  811. @lru_cache
  812. def is_quark_available() -> bool:
  813. return _is_package_available("quark")[0]
  814. @lru_cache
  815. def is_fp_quant_available():
  816. is_available, fp_quant_version = _is_package_available("fp_quant", return_version=True)
  817. return is_available and version.parse(fp_quant_version) >= version.parse("0.3.2")
  818. @lru_cache
  819. def is_qutlass_available():
  820. is_available, qutlass_version = _is_package_available("qutlass", return_version=True)
  821. return is_available and version.parse(qutlass_version) >= version.parse("0.2.0")
  822. @lru_cache
  823. def is_compressed_tensors_available() -> bool:
  824. return _is_package_available("compressed_tensors")[0]
  825. @lru_cache
  826. def is_sinq_available() -> bool:
  827. return _is_package_available("sinq")
  828. @lru_cache
  829. def is_gptqmodel_available() -> bool:
  830. return _is_package_available("gptqmodel")[0]
  831. @lru_cache
  832. def is_fbgemm_gpu_available() -> bool:
  833. return _is_package_available("fbgemm_gpu")[0]
  834. @lru_cache
  835. def is_levenshtein_available() -> bool:
  836. return _is_package_available("Levenshtein")[0]
  837. @lru_cache
  838. def is_optimum_neuron_available() -> bool:
  839. return is_optimum_available() and _is_package_available("optimum.neuron")[0]
  840. @lru_cache
  841. def is_tokenizers_available() -> bool:
  842. return _is_package_available("tokenizers")[0]
  843. @lru_cache
  844. def is_vision_available() -> bool:
  845. try:
  846. import PIL.Image # noqa: F401
  847. return True
  848. except ImportError:
  849. return False
  850. @lru_cache
  851. def is_pytesseract_available() -> bool:
  852. return _is_package_available("pytesseract")[0] and is_vision_available()
  853. @lru_cache
  854. def is_pytest_available() -> bool:
  855. return _is_package_available("pytest")[0]
  856. @lru_cache
  857. def is_pytest_order_available() -> bool:
  858. return is_pytest_available() and _is_package_available("pytest_order")[0]
  859. @lru_cache
  860. def is_spacy_available() -> bool:
  861. return _is_package_available("spacy")[0]
  862. @lru_cache
  863. def is_pytorch_quantization_available() -> bool:
  864. return _is_package_available("pytorch_quantization")[0]
  865. @lru_cache
  866. def is_pandas_available() -> bool:
  867. return _is_package_available("pandas")[0]
  868. @lru_cache
  869. def is_soundfile_available() -> bool:
  870. return _is_package_available("soundfile")[0]
  871. @lru_cache
  872. def is_timm_available() -> bool:
  873. return is_vision_available() and is_torch_available() and _is_package_available("timm")[0]
  874. @lru_cache
  875. def is_natten_available() -> bool:
  876. return _is_package_available("natten")[0]
  877. @lru_cache
  878. def is_nltk_available() -> bool:
  879. return _is_package_available("nltk")[0]
  880. @lru_cache
  881. def is_numba_available() -> bool:
  882. is_available = _is_package_available("numba")[0]
  883. if not is_available:
  884. return False
  885. numpy_available, numpy_version = _is_package_available("numpy", return_version=True)
  886. return not numpy_available or version.parse(numpy_version) < version.parse("2.2.0")
  887. @lru_cache
  888. def is_torchaudio_available() -> bool:
  889. return is_torch_available() and _is_package_available("torchaudio")[0]
  890. @lru_cache
  891. def is_torchao_available(min_version: str = TORCHAO_MIN_VERSION) -> bool:
  892. if not is_torch_available():
  893. return False
  894. is_available, torchao_version = _is_package_available("torchao", return_version=True)
  895. return is_available and version.parse(torchao_version) >= version.parse(min_version)
  896. @lru_cache
  897. def is_speech_available() -> bool:
  898. # For now this depends on torchaudio but the exact dependency might evolve in the future.
  899. return is_torchaudio_available()
  900. @lru_cache
  901. def is_spqr_available() -> bool:
  902. return _is_package_available("spqr_quant")[0]
  903. @lru_cache
  904. def is_phonemizer_available() -> bool:
  905. return _is_package_available("phonemizer")[0]
  906. @lru_cache
  907. def is_uroman_available() -> bool:
  908. return _is_package_available("uroman")[0]
  909. @lru_cache
  910. def is_sudachi_available() -> bool:
  911. return _is_package_available("sudachipy")[0]
  912. @lru_cache
  913. def is_sudachi_projection_available() -> bool:
  914. is_available, sudachipy_version = _is_package_available("sudachipy", return_version=True)
  915. return is_available and version.parse(sudachipy_version) >= version.parse("0.6.8")
  916. @lru_cache
  917. def is_jumanpp_available() -> bool:
  918. return _is_package_available("rhoknp")[0] and shutil.which("jumanpp") is not None
  919. @lru_cache
  920. def is_cython_available() -> bool:
  921. return _is_package_available("pyximport")[0]
  922. @lru_cache
  923. def is_jinja_available() -> bool:
  924. return _is_package_available("jinja2")[0]
  925. @lru_cache
  926. def is_jmespath_available() -> bool:
  927. return _is_package_available("jmespath")[0]
  928. @lru_cache
  929. def is_mlx_available() -> bool:
  930. return _is_package_available("mlx")[0]
  931. @lru_cache
  932. def is_num2words_available() -> bool:
  933. return _is_package_available("num2words")[0]
  934. @lru_cache
  935. def is_tiktoken_available(with_blobfile: bool = True) -> bool:
  936. if not _is_package_available("tiktoken")[0]:
  937. return False
  938. return with_blobfile and _is_package_available("blobfile")[0] or True
  939. @lru_cache
  940. def is_liger_kernel_available() -> bool:
  941. is_available, liger_kernel_version = _is_package_available("liger_kernel", return_version=True)
  942. return is_available and version.parse(liger_kernel_version) >= version.parse("0.3.0")
  943. @lru_cache
  944. def is_rich_available() -> bool:
  945. return _is_package_available("rich")[0]
  946. @lru_cache
  947. def is_matplotlib_available() -> bool:
  948. return _is_package_available("matplotlib")[0]
  949. @lru_cache
  950. def is_mistral_common_available() -> bool:
  951. return is_vision_available() and _is_package_available("mistral_common")[0]
  952. @lru_cache
  953. def is_opentelemetry_available() -> bool:
  954. try:
  955. return _is_package_available("opentelemetry")[0] and version.parse(
  956. importlib.metadata.version("opentelemetry-api")
  957. ) >= version.parse("1.30.0")
  958. except Exception as _:
  959. return False
  960. @lru_cache
  961. def is_pynvml_available() -> bool:
  962. return _is_package_available("pynvml")[0]
  963. def check_torch_load_is_safe() -> None:
  964. if not is_torch_greater_or_equal("2.6"):
  965. raise ValueError(
  966. "Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`, we now require users "
  967. "to upgrade torch to at least v2.6 in order to use the function. This version restriction does not apply "
  968. "when loading files with safetensors."
  969. "\nSee the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434"
  970. )
  971. def torch_only_method(fn: Callable) -> Callable:
  972. def wrapper(*args, **kwargs):
  973. if not is_torch_available():
  974. raise ImportError("You need to install pytorch to use this method or class")
  975. else:
  976. return fn(*args, **kwargs)
  977. return wrapper
  978. def is_torch_deterministic() -> bool:
  979. """
  980. Check whether pytorch uses deterministic algorithms by looking if torch.set_deterministic_debug_mode() is set to 1 or 2"
  981. """
  982. if is_torch_available():
  983. import torch
  984. if torch.get_deterministic_debug_mode() == 0:
  985. return False
  986. else:
  987. return True
  988. return False
  989. @lru_cache
  990. def get_torch_major_and_minor_version() -> str:
  991. torch_version = get_torch_version()
  992. if torch_version == "N/A":
  993. return "N/A"
  994. parsed_version = version.parse(torch_version)
  995. return str(parsed_version.major) + "." + str(parsed_version.minor)
  996. def is_torchdynamo_compiling() -> bool:
  997. # Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622)
  998. # hence rather relying on `torch.compiler.is_compiling()` when possible (torch>=2.3)
  999. try:
  1000. import torch
  1001. if hasattr(torch, "compiler"):
  1002. return torch.compiler.is_compiling()
  1003. return False
  1004. except Exception:
  1005. return False
  1006. def is_torchdynamo_exporting() -> bool:
  1007. try:
  1008. import torch
  1009. if hasattr(torch, "compiler"):
  1010. return torch.compiler.is_exporting()
  1011. return False
  1012. except Exception:
  1013. return False
  1014. def is_torch_fx_proxy(x) -> bool:
  1015. try:
  1016. import torch.fx
  1017. return isinstance(x, torch.fx.Proxy)
  1018. except Exception:
  1019. return False
  1020. def is_fake_tensor(x) -> bool:
  1021. try:
  1022. import torch
  1023. return isinstance(x, getattr(torch, "_subclasses").FakeTensor)
  1024. except Exception:
  1025. return False
  1026. def is_jax_jitting(x):
  1027. """returns True if we are inside of `jax.jit` context, False otherwise.
  1028. When a torch model is being compiled with `jax.jit` using torchax,
  1029. the tensor that goes through the model would be an instance of
  1030. `torchax.tensor.Tensor`, which is a tensor subclass. This tensor has
  1031. a `jax` method to return the inner Jax array
  1032. (https://github.com/google/torchax/blob/13ce870a1d9adb2430333c27bb623469e3aea34e/torchax/tensor.py#L134).
  1033. Here we use ducktyping to detect if the inner jax array is a jax Tracer
  1034. then we are in tracing context. (See more at: https://github.com/jax-ml/jax/discussions/9241)
  1035. Args:
  1036. x: torch.Tensor
  1037. Returns:
  1038. bool: whether we are inside of jax jit tracing.
  1039. """
  1040. if not hasattr(x, "jax"):
  1041. return False
  1042. try:
  1043. import jax
  1044. return isinstance(x.jax(), getattr(jax, "core").Tracer)
  1045. except Exception:
  1046. return False
  1047. def is_jit_tracing() -> bool:
  1048. try:
  1049. import torch
  1050. return torch.jit.is_tracing()
  1051. except Exception:
  1052. return False
  1053. def is_cuda_stream_capturing() -> bool:
  1054. try:
  1055. import torch
  1056. return torch.cuda.is_current_stream_capturing()
  1057. except Exception:
  1058. return False
  1059. def is_tracing(tensor=None) -> bool:
  1060. """Checks whether we are tracing a graph with dynamo (compile or export), torch.jit, torch.fx, jax.jit (with torchax) or
  1061. CUDA stream capturing or FakeTensor"""
  1062. # Note that `is_torchdynamo_compiling` checks both compiling and exporting (the export check is stricter and
  1063. # only checks export)
  1064. _is_tracing = is_torchdynamo_compiling() or is_jit_tracing() or is_cuda_stream_capturing()
  1065. if tensor is not None:
  1066. _is_tracing |= is_torch_fx_proxy(tensor)
  1067. _is_tracing |= is_fake_tensor(tensor)
  1068. _is_tracing |= is_jax_jitting(tensor)
  1069. return _is_tracing
  1070. def torch_compilable_check(cond: Any, msg: str | Callable[[], str], error_type: type[Exception] = ValueError) -> None:
  1071. """
  1072. Combines the functionalities of `torch._check`, `torch._check_with` and `torch._check_tensor_all_with` to provide a
  1073. unified way to perform checks that are compatible with TorchDynamo (torch.compile & torch.export).
  1074. The advantage of using `torch._check(cond, msg, error_type)` over `if cond: raise error_type(msg)` is that the former
  1075. works as a truthfulness hint for TorchDynamo, instead of failing with a data-dependent control flow error during compilation.
  1076. All checks using this method can be disabled in production environments by setting `TRANSFORMERS_DISABLE_TORCH_CHECK=1`.
  1077. Args:
  1078. cond (`bool`, `torch.Tensor` or `Callable[[], bool | torch.Tensor]`): The condition to check.
  1079. msg (`str` or `Callable[[], str]`): The error message to display if the condition is not met.
  1080. error_type (`type[Exception]`, *optional*, defaults to `ValueError`): The type of error to raise if the condition is not met.
  1081. Raises:
  1082. error_type: If the condition is not met.
  1083. """
  1084. if os.getenv("TRANSFORMERS_DISABLE_TORCH_CHECK", "0") == "1":
  1085. return
  1086. import torch
  1087. if not callable(msg):
  1088. # torch._check requires msg to be a callable but we want to keep the API simple for users
  1089. def msg_callable():
  1090. return msg
  1091. else:
  1092. msg_callable = msg
  1093. if callable(cond):
  1094. cond = cond()
  1095. # These checks are also compiler hints for TorchDynamo telling
  1096. # it that the condition is expected to be True during compilation
  1097. if isinstance(cond, torch.Tensor):
  1098. torch._check_tensor_all_with(error_type, cond, msg_callable)
  1099. else:
  1100. torch._check_with(error_type, cond, msg_callable)
  1101. @lru_cache
  1102. def is_in_notebook() -> bool:
  1103. try:
  1104. # Check if we are running inside Marimo
  1105. if "marimo" in sys.modules:
  1106. return True
  1107. # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
  1108. get_ipython = sys.modules["IPython"].get_ipython
  1109. if "IPKernelApp" not in get_ipython().config:
  1110. raise ImportError("console")
  1111. # Removed the lines to include VSCode
  1112. if "DATABRICKS_RUNTIME_VERSION" in os.environ and os.environ["DATABRICKS_RUNTIME_VERSION"] < "11.0":
  1113. # Databricks Runtime 11.0 and above uses IPython kernel by default so it should be compatible with Jupyter notebook
  1114. # https://docs.microsoft.com/en-us/azure/databricks/notebooks/ipython-kernel
  1115. raise ImportError("databricks")
  1116. return importlib.util.find_spec("IPython") is not None
  1117. except (AttributeError, ImportError, KeyError):
  1118. return False
  1119. def is_sagemaker_dp_enabled() -> bool:
  1120. # Get the sagemaker specific env variable.
  1121. sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
  1122. try:
  1123. # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
  1124. sagemaker_params = json.loads(sagemaker_params)
  1125. if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
  1126. return False
  1127. except json.JSONDecodeError:
  1128. return False
  1129. # Lastly, check if the `smdistributed` module is present.
  1130. return _is_package_available("smdistributed")[0]
  1131. def is_sagemaker_mp_enabled() -> bool:
  1132. # Get the sagemaker specific mp parameters from smp_options variable.
  1133. smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
  1134. try:
  1135. # Parse it and check the field "partitions" is included, it is required for model parallel.
  1136. smp_options = json.loads(smp_options)
  1137. if "partitions" not in smp_options:
  1138. return False
  1139. except json.JSONDecodeError:
  1140. return False
  1141. # Get the sagemaker specific framework parameters from mpi_options variable.
  1142. mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
  1143. try:
  1144. # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
  1145. mpi_options = json.loads(mpi_options)
  1146. if not mpi_options.get("sagemaker_mpi_enabled", False):
  1147. return False
  1148. except json.JSONDecodeError:
  1149. return False
  1150. # Lastly, check if the `smdistributed` module is present.
  1151. return _is_package_available("smdistributed")[0]
  1152. def is_training_run_on_sagemaker() -> bool:
  1153. return "SAGEMAKER_JOB_NAME" in os.environ
# Error-message templates for optional backends. `requires_backends` formats them with `msg.format(name)`, so `{0}`
# is replaced by the name of the object that needs the backend; any other unescaped `{...}` field would make
# `.format` fail.
# docstyle-ignore
AV_IMPORT_ERROR = """
{0} requires the PyAv library but it was not found in your environment. You can install it with:
```
pip install av
```
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
YT_DLP_IMPORT_ERROR = """
{0} requires the YT-DLP library but it was not found in your environment. You can install it with:
```
pip install yt-dlp
```
Please note that you may need to restart your runtime after installation.
"""
  1170. DECORD_IMPORT_ERROR = """
  1171. {0} requires the PyAv library but it was not found in your environment. You can install it with:
  1172. ```
  1173. pip install decord
  1174. ```
  1175. Please note that you may need to restart your runtime after installation.
  1176. """
# Error-message templates for optional backends. `requires_backends` formats them with `msg.format(name)`, so `{0}`
# is replaced by the name of the object that needs the backend; any other unescaped `{...}` field would make
# `.format` fail.
TORCHCODEC_IMPORT_ERROR = """
{0} requires the TorchCodec (https://github.com/pytorch/torchcodec) library, but it was not found in your environment. You can install it with:
```
pip install torchcodec
```
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
CV2_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
```
pip install opencv-python
```
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
DATASETS_IMPORT_ERROR = """
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
```
pip install datasets
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install datasets
```
then restarting your kernel.
Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current
working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
that python file if that's the case. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
TOKENIZERS_IMPORT_ERROR = """
{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:
```
pip install tokenizers
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install tokenizers
```
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
SENTENCEPIECE_IMPORT_ERROR = """
{0} requires the SentencePiece library but it was not found in your environment. Check out the instructions on the
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
PROTOBUF_IMPORT_ERROR = """
{0} requires the protobuf library but it was not found in your environment. Check out the instructions on the
installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
FAISS_IMPORT_ERROR = """
{0} requires the faiss library but it was not found in your environment. Check out the instructions on the
installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
TORCHVISION_IMPORT_ERROR = """
{0} requires the Torchvision library but it was not found in your environment. Check out the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
BS4_IMPORT_ERROR = """
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
SKLEARN_IMPORT_ERROR = """
{0} requires the scikit-learn library but it was not found in your environment. You can install it with:
```
pip install -U scikit-learn
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install -U scikit-learn
```
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
DETECTRON2_IMPORT_ERROR = """
{0} requires the detectron2 library but it was not found in your environment. Check out the instructions on the
installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""
LEVENSHTEIN_IMPORT_ERROR = """
{0} requires the python-Levenshtein library but it was not found in your environment. You can install it with pip: `pip
install python-Levenshtein`. Please note that you may need to restart your runtime after installation.
"""
# Error-message templates for optional backends. `requires_backends` formats them with `msg.format(name)`, so `{0}`
# is replaced by the name of the object that needs the backend; any other unescaped `{...}` field would make
# `.format` fail.
# docstyle-ignore
G2P_EN_IMPORT_ERROR = """
{0} requires the g2p-en library but it was not found in your environment. You can install it with pip:
`pip install g2p-en`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
PYTORCH_QUANTIZATION_IMPORT_ERROR = """
{0} requires the pytorch-quantization library but it was not found in your environment. You can install it with pip:
`pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com`
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
TORCHAUDIO_IMPORT_ERROR = """
{0} requires the torchaudio library but it was not found in your environment. Please install it and restart your
runtime.
"""
# docstyle-ignore
PANDAS_IMPORT_ERROR = """
{0} requires the pandas library but it was not found in your environment. You can install it with pip as
explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
PHONEMIZER_IMPORT_ERROR = """
{0} requires the phonemizer library but it was not found in your environment. You can install it with pip:
`pip install phonemizer`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
UROMAN_IMPORT_ERROR = """
{0} requires the uroman library but it was not found in your environment. You can install it with pip:
`pip install uroman`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
SACREMOSES_IMPORT_ERROR = """
{0} requires the sacremoses library but it was not found in your environment. You can install it with pip:
`pip install sacremoses`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip:
`pip install scipy`. Please note that you may need to restart your runtime after installation.
"""
# Template for the "speech" backend; the missing dependency it reports is torchaudio.
# docstyle-ignore
SPEECH_IMPORT_ERROR = """
{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
`pip install torchaudio`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
TIMM_IMPORT_ERROR = """
{0} requires the timm library but it was not found in your environment. You can install it with pip:
`pip install timm`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
NATTEN_IMPORT_ERROR = """
{0} requires the natten library but it was not found in your environment. You can install it by referring to:
shi-labs.com/natten . You can also install it with pip (may take longer to build):
`pip install natten`. Please note that you may need to restart your runtime after installation.
"""
NUMEXPR_IMPORT_ERROR = """
{0} requires the numexpr library but it was not found in your environment. You can install it by referring to:
https://numexpr.readthedocs.io/en/latest/index.html.
"""
# docstyle-ignore
NLTK_IMPORT_ERROR = """
{0} requires the NLTK library but it was not found in your environment. You can install it by referring to:
https://www.nltk.org/install.html. Please note that you may need to restart your runtime after installation.
"""
# Template for the "vision" backend, which reports PIL (installed via the `pillow` package) as missing.
# docstyle-ignore
VISION_IMPORT_ERROR = """
{0} requires the PIL library but it was not found in your environment. You can install it with pip:
`pip install pillow`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
PYDANTIC_IMPORT_ERROR = """
{0} requires the pydantic library but it was not found in your environment. You can install it with pip:
`pip install pydantic`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
FASTAPI_IMPORT_ERROR = """
{0} requires the fastapi library but it was not found in your environment. You can install it with pip:
`pip install fastapi`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
UVICORN_IMPORT_ERROR = """
{0} requires the uvicorn library but it was not found in your environment. You can install it with pip:
`pip install uvicorn`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
OPENAI_IMPORT_ERROR = """
{0} requires the openai library but it was not found in your environment. You can install it with pip:
`pip install openai`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
PYTESSERACT_IMPORT_ERROR = """
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
`pip install pytesseract`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
PYCTCDECODE_IMPORT_ERROR = """
{0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip:
`pip install pyctcdecode`. Please note that you may need to restart your runtime after installation.
"""
  1378. # docstyle-ignore
  1379. ACCELERATE_IMPORT_ERROR = """
  1380. {0} requires the accelerate library >= {ACCELERATE_MIN_VERSION} it was not found in your environment.
  1381. You can install or update it with pip: `pip install --upgrade accelerate`. Please note that you may need to restart your
  1382. runtime after installation.
  1383. """
# Error-message templates for optional backends. `requires_backends` formats them with `msg.format(name)`, so `{0}`
# is replaced by the name of the object that needs the backend; any other unescaped `{...}` field would make
# `.format` fail.
# docstyle-ignore
ESSENTIA_IMPORT_ERROR = """
{0} requires essentia library. But that was not found in your environment. You can install them with pip:
`pip install essentia==2.1b6.dev1034`
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
LIBROSA_IMPORT_ERROR = """
{0} requires the librosa library. But that was not found in your environment. You can install them with pip:
`pip install librosa`
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
PRETTY_MIDI_IMPORT_ERROR = """
{0} requires the pretty_midi library. But that was not found in your environment. You can install them with pip:
`pip install pretty_midi`
Please note that you may need to restart your runtime after installation.
"""
CYTHON_IMPORT_ERROR = """
{0} requires the Cython library but it was not found in your environment. You can install it with pip: `pip install
Cython`. Please note that you may need to restart your runtime after installation.
"""
RJIEBA_IMPORT_ERROR = """
{0} requires the rjieba library but it was not found in your environment. You can install it with pip: `pip install
rjieba`. Please note that you may need to restart your runtime after installation.
"""
PEFT_IMPORT_ERROR = """
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install
peft`. Please note that you may need to restart your runtime after installation.
"""
# Template for the "jinja" backend; the package to install is `jinja2`.
JINJA_IMPORT_ERROR = """
{0} requires the jinja library but it was not found in your environment. You can install it with pip: `pip install
jinja2`. Please note that you may need to restart your runtime after installation.
"""
RICH_IMPORT_ERROR = """
{0} requires the rich library but it was not found in your environment. You can install it with pip: `pip install
rich`. Please note that you may need to restart your runtime after installation.
"""
MISTRAL_COMMON_IMPORT_ERROR = """
{0} requires the mistral-common library but it was not found in your environment. You can install it with pip: `pip install mistral-common`. Please note that you may need to restart your runtime after installation.
"""
  1425. BACKENDS_MAPPING = OrderedDict(
  1426. [
  1427. ("av", (is_av_available, AV_IMPORT_ERROR)),
  1428. ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
  1429. ("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
  1430. ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
  1431. ("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
  1432. ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
  1433. ("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)),
  1434. ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
  1435. ("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)),
  1436. ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
  1437. ("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)),
  1438. ("uroman", (is_uroman_available, UROMAN_IMPORT_ERROR)),
  1439. ("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)),
  1440. ("levenshtein", (is_levenshtein_available, LEVENSHTEIN_IMPORT_ERROR)),
  1441. ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
  1442. ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
  1443. ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
  1444. ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
  1445. ("sacremoses", (is_sacremoses_available, SACREMOSES_IMPORT_ERROR)),
  1446. ("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)),
  1447. ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
  1448. ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
  1449. ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)),
  1450. ("timm", (is_timm_available, TIMM_IMPORT_ERROR)),
  1451. ("torchaudio", (is_torchaudio_available, TORCHAUDIO_IMPORT_ERROR)),
  1452. ("natten", (is_natten_available, NATTEN_IMPORT_ERROR)),
  1453. ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
  1454. ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
  1455. ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
  1456. ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
  1457. ("torchcodec", (is_torchcodec_available, TORCHCODEC_IMPORT_ERROR)),
  1458. ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
  1459. ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
  1460. ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
  1461. ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
  1462. ("rjieba", (is_rjieba_available, RJIEBA_IMPORT_ERROR)),
  1463. ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
  1464. ("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)),
  1465. ("yt_dlp", (is_yt_dlp_available, YT_DLP_IMPORT_ERROR)),
  1466. ("rich", (is_rich_available, RICH_IMPORT_ERROR)),
  1467. ("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
  1468. ("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)),
  1469. ("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)),
  1470. ("openai", (is_openai_available, OPENAI_IMPORT_ERROR)),
  1471. ("mistral-common", (is_mistral_common_available, MISTRAL_COMMON_IMPORT_ERROR)),
  1472. ]
  1473. )
  1474. def requires_backends(obj, backends):
  1475. """
  1476. Method that automatically raises in case the specified backends are not available. It is often used during class
  1477. initialization to ensure the required dependencies are installed:
  1478. ```py
  1479. requires_backends(self, ["torch"])
  1480. ```
  1481. The backends should be defined in the `BACKEND_MAPPING` defined in `transformers.utils.import_utils`.
  1482. Args:
  1483. obj: object to be checked
  1484. backends: list or tuple of backends to check.
  1485. """
  1486. if not isinstance(backends, list | tuple):
  1487. backends = [backends]
  1488. name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
  1489. failed = []
  1490. for backend in backends:
  1491. if isinstance(backend, Backend):
  1492. available, msg = backend.is_satisfied, backend.error_message
  1493. else:
  1494. available, msg = BACKENDS_MAPPING[backend]
  1495. if not available():
  1496. failed.append(msg.format(name))
  1497. if failed:
  1498. raise ImportError("".join(failed))
  1499. class DummyObject(type):
  1500. """
  1501. Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
  1502. `requires_backend` each time a user tries to access any method of that class.
  1503. """
  1504. is_dummy = True
  1505. def __getattribute__(cls, key):
  1506. if (key.startswith("_") and key != "_from_config") or key == "is_dummy" or key == "mro" or key == "call":
  1507. return super().__getattribute__(key)
  1508. requires_backends(cls, cls._backends)
  1509. BACKENDS_T = frozenset[str]
  1510. IMPORT_STRUCTURE_T = dict[BACKENDS_T, dict[str, set[str]]]
  1511. class _LazyModule(ModuleType):
  1512. """
  1513. Module class that surfaces all objects but only performs associated imports when the objects are requested.
  1514. """
  1515. # Very heavily inspired by optuna.integration._IntegrationModule
  1516. # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
  1517. def __init__(
  1518. self,
  1519. name: str,
  1520. module_file: str,
  1521. import_structure: IMPORT_STRUCTURE_T,
  1522. module_spec: importlib.machinery.ModuleSpec | None = None,
  1523. extra_objects: dict[str, object] | None = None,
  1524. explicit_import_shortcut: dict[str, list[str]] | None = None,
  1525. ):
  1526. super().__init__(name)
  1527. self._object_missing_backend = {}
  1528. self._explicit_import_shortcut = explicit_import_shortcut if explicit_import_shortcut else {}
  1529. if any(isinstance(key, frozenset) for key in import_structure):
  1530. self._modules = set()
  1531. self._class_to_module = {}
  1532. self.__all__ = []
  1533. _import_structure = {}
  1534. for backends, module in import_structure.items():
  1535. missing_backends = []
  1536. # This ensures that if a module is importable, then all other keys of the module are importable.
  1537. # As an example, in module.keys() we might have the following:
  1538. #
  1539. # dict_keys(['models.nllb_moe.configuration_nllb_moe', 'models.sew_d.configuration_sew_d'])
  1540. #
  1541. # with this, we don't only want to be able to import these explicitly, we want to be able to import
  1542. # every intermediate module as well. Therefore, this is what is returned:
  1543. #
  1544. # {
  1545. # 'models.nllb_moe.configuration_nllb_moe',
  1546. # 'models.sew_d.configuration_sew_d',
  1547. # 'models',
  1548. # 'models.sew_d', 'models.nllb_moe'
  1549. # }
  1550. module_keys = set(
  1551. chain(*[[k.rsplit(".", i)[0] for i in range(k.count(".") + 1)] for k in list(module.keys())])
  1552. )
  1553. for backend in backends:
  1554. if backend in BACKENDS_MAPPING:
  1555. callable, _ = BACKENDS_MAPPING[backend]
  1556. else:
  1557. if any(key in backend for key in ["=", "<", ">"]):
  1558. backend = Backend(backend)
  1559. callable = backend.is_satisfied
  1560. else:
  1561. raise ValueError(
  1562. f"Backend should be defined in the BACKENDS_MAPPING. Offending backend: {backend}"
  1563. )
  1564. try:
  1565. if not callable():
  1566. missing_backends.append(backend)
  1567. except (ModuleNotFoundError, RuntimeError):
  1568. missing_backends.append(backend)
  1569. self._modules = self._modules.union(module_keys)
  1570. for key, values in module.items():
  1571. if missing_backends:
  1572. self._object_missing_backend[key] = missing_backends
  1573. for value in values:
  1574. self._class_to_module[value] = key
  1575. if missing_backends:
  1576. self._object_missing_backend[value] = missing_backends
  1577. _import_structure.setdefault(key, []).extend(values)
  1578. # Needed for autocompletion in an IDE
  1579. self.__all__.extend(module_keys | set(chain(*module.values())))
  1580. self.__file__ = module_file
  1581. self.__spec__ = module_spec
  1582. self.__path__ = [os.path.dirname(module_file)]
  1583. self._objects = {} if extra_objects is None else extra_objects
  1584. self._name = name
  1585. self._import_structure = _import_structure
  1586. # This can be removed once every exportable object has a `require()` require.
  1587. else:
  1588. self._modules = set(import_structure.keys())
  1589. self._class_to_module = {}
  1590. for key, values in import_structure.items():
  1591. for value in values:
  1592. self._class_to_module[value] = key
  1593. # Needed for autocompletion in an IDE
  1594. self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
  1595. self.__file__ = module_file
  1596. self.__spec__ = module_spec
  1597. self.__path__ = [os.path.dirname(module_file)]
  1598. self._objects = {} if extra_objects is None else extra_objects
  1599. self._name = name
  1600. self._import_structure = import_structure
  1601. # Needed for autocompletion in an IDE
  1602. def __dir__(self):
  1603. result = list(super().__dir__())
  1604. # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
  1605. # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
  1606. for attr in self.__all__:
  1607. if attr not in result:
  1608. result.append(attr)
  1609. return result
  1610. def __getattr__(self, name: str) -> Any:
  1611. if name in self._objects:
  1612. return self._objects[name]
  1613. if name in self._object_missing_backend:
  1614. missing_backends = self._object_missing_backend[name]
  1615. # Backward-compat fallback: before the image processor refactoring, the base
  1616. # `<Model>ImageProcessor` name referred to the PIL/slow backend. After the refactoring
  1617. # it refers to the TorchvisionBackend (which requires torchvision). So if torchvision
  1618. # is not installed, transparently fall back to `<Model>ImageProcessorPil` and warn once.
  1619. if "torchvision" in missing_backends and name.endswith("ImageProcessor"):
  1620. pil_name = f"{name}Pil"
  1621. if pil_name in self._class_to_module and pil_name not in self._object_missing_backend:
  1622. try:
  1623. pil_module = self._get_module(self._class_to_module[pil_name])
  1624. pil_value = getattr(pil_module, pil_name)
  1625. logger.warning_once(
  1626. f"`{name}` requires torchvision (not installed); falling back to `{pil_name}` "
  1627. f"for backward compatibility. Install torchvision to use the default backend, "
  1628. f"or import `{pil_name}` directly to silence this warning."
  1629. )
  1630. setattr(self, name, pil_value)
  1631. return pil_value
  1632. except Exception as e:
  1633. logger.debug(f"Could not load PIL fallback {pil_name}: {e}")
  1634. class Placeholder(metaclass=DummyObject):
  1635. _backends = missing_backends
  1636. def __init__(self, *args, **kwargs):
  1637. requires_backends(self, missing_backends)
  1638. def call(self, *args, **kwargs):
  1639. pass
  1640. Placeholder.__name__ = name
  1641. if name not in self._class_to_module:
  1642. module_name = f"transformers.{name}"
  1643. else:
  1644. module_name = self._class_to_module[name]
  1645. if not module_name.startswith("transformers."):
  1646. module_name = f"transformers.{module_name}"
  1647. Placeholder.__module__ = module_name
  1648. value = Placeholder
  1649. elif name in self._class_to_module:
  1650. try:
  1651. module = self._get_module(self._class_to_module[name])
  1652. value = getattr(module, name)
  1653. except (ModuleNotFoundError, RuntimeError, AttributeError) as e:
  1654. # V5: If trying to import a *TokenizerFast symbol, transparently fall back to the
  1655. # non-Fast symbol from the same module when available. This lets us keep only one
  1656. # backend tokenizer class while preserving legacy public names.
  1657. if name.endswith("TokenizerFast"):
  1658. fallback_name = name[:-4]
  1659. # Prefer importing the module that declares the fallback symbol if known
  1660. try:
  1661. if fallback_name in self._class_to_module:
  1662. fb_module = self._get_module(self._class_to_module[fallback_name])
  1663. fallback_value = getattr(fb_module, fallback_name)
  1664. else:
  1665. module = self._get_module(self._class_to_module[name])
  1666. fallback_value = getattr(module, fallback_name)
  1667. setattr(self, fallback_name, fallback_value)
  1668. value = fallback_value
  1669. except Exception:
  1670. # If we can't find the fallback here, try converter logic as a last resort
  1671. # before giving up
  1672. value = None
  1673. # Try converter mapping for Fast tokenizers that don't exist
  1674. if value is None and name.endswith("TokenizerFast"):
  1675. lookup_name = name[:-4]
  1676. try:
  1677. from ..convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
  1678. if lookup_name in SLOW_TO_FAST_CONVERTERS:
  1679. converter_class = SLOW_TO_FAST_CONVERTERS[lookup_name]
  1680. converter_base_name = converter_class.__name__.replace("Converter", "")
  1681. preferred_tokenizer_name = f"{converter_base_name}Tokenizer"
  1682. candidate_names = [preferred_tokenizer_name]
  1683. for tokenizer_name, tokenizer_converter in SLOW_TO_FAST_CONVERTERS.items():
  1684. if tokenizer_converter is converter_class and tokenizer_name != lookup_name:
  1685. if tokenizer_name not in candidate_names:
  1686. candidate_names.append(tokenizer_name)
  1687. # Try to import the preferred candidate directly
  1688. import importlib
  1689. for candidate_name in candidate_names:
  1690. base_tokenizer_class = None
  1691. # Try to derive module path from tokenizer name (e.g., "AlbertTokenizer" -> "albert")
  1692. # Remove "Tokenizer" suffix and convert to lowercase
  1693. if candidate_name.endswith("Tokenizer"):
  1694. model_name = candidate_name[:-10].lower() # Remove "Tokenizer"
  1695. module_path = f"transformers.models.{model_name}.tokenization_{model_name}"
  1696. try:
  1697. module = importlib.import_module(module_path)
  1698. base_tokenizer_class = getattr(module, candidate_name)
  1699. except Exception:
  1700. logger.debug(f"{module_path} does not have {candidate_name} defined.")
  1701. # Fallback: try via _class_to_module
  1702. if base_tokenizer_class is None and candidate_name in self._class_to_module:
  1703. try:
  1704. alias_module_name = self._class_to_module[candidate_name]
  1705. alias_module = self._get_module(alias_module_name)
  1706. base_tokenizer_class = getattr(alias_module, candidate_name)
  1707. except Exception:
  1708. logger.debug(
  1709. f"{alias_module_name} does not have {candidate_name} defined"
  1710. )
  1711. # If we still don't have base_tokenizer_class, skip this candidate
  1712. if base_tokenizer_class is None:
  1713. logger.debug(f"skipping candidate {candidate_name}")
  1714. continue
  1715. # If we got here, we have base_tokenizer_class
  1716. value = base_tokenizer_class
  1717. setattr(self, candidate_name, base_tokenizer_class)
  1718. if lookup_name != candidate_name:
  1719. setattr(self, lookup_name, value)
  1720. setattr(self, name, value)
  1721. break
  1722. except Exception as e:
  1723. logger.debug(f"Could not create tokenizer alias: {e}")
  1724. if value is None:
  1725. raise ModuleNotFoundError(
  1726. f"Could not import module '{name}'. Are this object's requirements defined correctly?"
  1727. ) from e
  1728. else:
  1729. raise ModuleNotFoundError(
  1730. f"Could not import module '{name}'. Are this object's requirements defined correctly?"
  1731. ) from e
  1732. elif name in self._modules:
  1733. try:
  1734. value = self._get_module(name)
  1735. except (ModuleNotFoundError, RuntimeError) as e:
  1736. raise ModuleNotFoundError(
  1737. f"Could not import module '{name}'. Are this object's requirements defined correctly?"
  1738. ) from e
  1739. else:
  1740. # V5: If a *TokenizerFast symbol is requested but not present in the import structure,
  1741. # try to resolve to the corresponding non-Fast symbol's module if available.
  1742. if name.endswith("TokenizerFast"):
  1743. fallback_name = name[:-4]
  1744. if fallback_name in self._class_to_module:
  1745. try:
  1746. fb_module = self._get_module(self._class_to_module[fallback_name])
  1747. value = getattr(fb_module, fallback_name)
  1748. setattr(self, fallback_name, value)
  1749. setattr(self, name, value)
  1750. return value
  1751. except Exception as e:
  1752. logger.debug(f"Could not load fallback {fallback_name}: {e}")
  1753. # V5: Handle *ImageProcessorFast backward compatibility
  1754. # Similar to TokenizerFast, but for image processors
  1755. if name.endswith("ImageProcessorFast"):
  1756. fallback_name = name[:-4] # Remove "Fast"
  1757. if fallback_name in self._class_to_module:
  1758. logger.warning_once(
  1759. f"`{name}` is deprecated. The `Fast` suffix for image processors has been removed; "
  1760. f"use `{fallback_name}` instead."
  1761. )
  1762. if fallback_name in self._object_missing_backend:
  1763. # The Fast alias has no entry in the import structure, so `requires_backends` on
  1764. # the real class never runs. Handle the missing backend explicitly here, otherwise
  1765. # `_get_module` swallows the ImportError and the caller gets an AttributeError.
  1766. # Do not fall through to the PIL fallback since a legacy "Fast" image processor was explicitly requested.
  1767. missing_backends = self._object_missing_backend[fallback_name]
  1768. class Placeholder(metaclass=DummyObject):
  1769. _backends = missing_backends
  1770. def __init__(self, *args, **kwargs):
  1771. requires_backends(self, missing_backends)
  1772. def call(self, *args, **kwargs):
  1773. pass
  1774. Placeholder.__name__ = fallback_name
  1775. module_name = self._class_to_module[fallback_name]
  1776. Placeholder.__module__ = (
  1777. module_name if module_name.startswith("transformers.") else f"transformers.{module_name}"
  1778. )
  1779. setattr(self, name, Placeholder)
  1780. return Placeholder
  1781. try:
  1782. fb_module = self._get_module(self._class_to_module[fallback_name])
  1783. value = getattr(fb_module, fallback_name)
  1784. setattr(self, fallback_name, value)
  1785. setattr(self, name, value)
  1786. return value
  1787. except Exception as e:
  1788. logger.debug(f"Could not load fallback {fallback_name}: {e}")
  1789. # V5: If a tokenizer class doesn't exist, check if it should alias to another tokenizer
  1790. # via the converter mapping (e.g., FNetTokenizer -> AlbertTokenizer via AlbertConverter)
  1791. value = None
  1792. if name.endswith("Tokenizer") or name.endswith("TokenizerFast"):
  1793. # Strip "Fast" suffix for converter lookup if present
  1794. lookup_name = name[:-4] if name.endswith("TokenizerFast") else name
  1795. try:
  1796. # Lazy import to avoid circular dependencies
  1797. from ..convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
  1798. # Check if this tokenizer has a converter mapping
  1799. if lookup_name in SLOW_TO_FAST_CONVERTERS:
  1800. converter_class = SLOW_TO_FAST_CONVERTERS[lookup_name]
  1801. # Find which tokenizer class uses the same converter (reverse lookup)
  1802. # Prefer the tokenizer that matches the converter name pattern
  1803. # (e.g., AlbertConverter -> AlbertTokenizer)
  1804. converter_base_name = converter_class.__name__.replace("Converter", "")
  1805. preferred_tokenizer_name = f"{converter_base_name}Tokenizer"
  1806. # Try preferred tokenizer first
  1807. candidate_names = [preferred_tokenizer_name]
  1808. # Then try all other tokenizers with the same converter
  1809. for tokenizer_name, tokenizer_converter in SLOW_TO_FAST_CONVERTERS.items():
  1810. if tokenizer_converter is converter_class and tokenizer_name != lookup_name:
  1811. if tokenizer_name not in candidate_names:
  1812. candidate_names.append(tokenizer_name)
  1813. # Try to import one of the candidate tokenizers
  1814. for candidate_name in candidate_names:
  1815. if candidate_name in self._class_to_module:
  1816. try:
  1817. alias_module = self._get_module(self._class_to_module[candidate_name])
  1818. base_tokenizer_class = getattr(alias_module, candidate_name)
  1819. value = base_tokenizer_class
  1820. # Cache both names for future imports
  1821. setattr(self, candidate_name, base_tokenizer_class)
  1822. if lookup_name != candidate_name:
  1823. setattr(self, lookup_name, value)
  1824. setattr(self, name, value)
  1825. break
  1826. except Exception:
  1827. # If this candidate fails, try the next one
  1828. continue
  1829. else:
  1830. # Candidate not in _class_to_module - might need recursive resolution
  1831. # Try importing it directly to trigger lazy loading
  1832. try:
  1833. # Try to get it from transformers module to trigger lazy loading
  1834. transformers_module = sys.modules.get("transformers")
  1835. if transformers_module and hasattr(transformers_module, candidate_name):
  1836. base_tokenizer_class = getattr(transformers_module, candidate_name)
  1837. value = base_tokenizer_class
  1838. if lookup_name != candidate_name:
  1839. setattr(self, lookup_name, value)
  1840. setattr(self, name, value)
  1841. break
  1842. except Exception:
  1843. continue
  1844. except (ImportError, AttributeError):
  1845. pass
  1846. if value is None:
  1847. for key, values in self._explicit_import_shortcut.items():
  1848. if name in values:
  1849. value = self._get_module(key)
  1850. break
  1851. if value is None:
  1852. raise AttributeError(f"module {self.__name__} has no attribute {name}")
  1853. setattr(self, name, value)
  1854. return value
  1855. def _get_module(self, module_name: str):
  1856. try:
  1857. return importlib.import_module("." + module_name, self.__name__)
  1858. except Exception as e:
  1859. raise e
  1860. def __reduce__(self):
  1861. return (self.__class__, (self._name, self.__file__, self._import_structure))
  1862. class OptionalDependencyNotAvailable(BaseException):
  1863. """Internally used error class for signalling an optional dependency was not found."""
  1864. def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
  1865. """Imports transformers directly
  1866. Args:
  1867. path (`str`): The path to the source file
  1868. file (`str`, *optional*): The file to join with the path. Defaults to "__init__.py".
  1869. Returns:
  1870. `ModuleType`: The resulting imported module
  1871. """
  1872. name = "transformers"
  1873. location = os.path.join(path, file)
  1874. spec = importlib.util.spec_from_file_location(name, location, submodule_search_locations=[path])
  1875. if spec is not None and spec.loader is not None:
  1876. module = importlib.util.module_from_spec(spec)
  1877. spec.loader.exec_module(module)
  1878. module = sys.modules[name]
  1879. return module
  1880. raise ImportError(f"Could not load module {name} from {location}")
  1881. class VersionComparison(Enum):
  1882. EQUAL = operator.eq
  1883. NOT_EQUAL = operator.ne
  1884. GREATER_THAN = operator.gt
  1885. LESS_THAN = operator.lt
  1886. GREATER_THAN_OR_EQUAL = operator.ge
  1887. LESS_THAN_OR_EQUAL = operator.le
  1888. @staticmethod
  1889. def from_string(version_string: str) -> "VersionComparison":
  1890. string_to_operator = {
  1891. "=": VersionComparison.EQUAL,
  1892. "==": VersionComparison.EQUAL,
  1893. "!=": VersionComparison.NOT_EQUAL,
  1894. ">": VersionComparison.GREATER_THAN,
  1895. "<": VersionComparison.LESS_THAN,
  1896. ">=": VersionComparison.GREATER_THAN_OR_EQUAL,
  1897. "<=": VersionComparison.LESS_THAN_OR_EQUAL,
  1898. }
  1899. return string_to_operator[version_string]
  1900. @lru_cache
  1901. def split_package_version(package_version_str) -> tuple[str, str, str]:
  1902. pattern = r"([a-zA-Z0-9_-]+)([!<>=~]+)([0-9.]+)"
  1903. match = re.match(pattern, package_version_str)
  1904. if match:
  1905. return (match.group(1), match.group(2), match.group(3))
  1906. else:
  1907. raise ValueError(f"Invalid package version string: {package_version_str}")
  1908. class Backend:
  1909. def __init__(self, backend_requirement: str):
  1910. self.package_name, self.version_comparison, self.version = split_package_version(backend_requirement)
  1911. if self.package_name not in BACKENDS_MAPPING:
  1912. raise ValueError(
  1913. f"Backends should be defined in the BACKENDS_MAPPING. Offending backend: {self.package_name}"
  1914. )
  1915. def get_installed_version(self) -> str:
  1916. """Return the currently installed version of the backend"""
  1917. is_available, current_version = _is_package_available(self.package_name, return_version=True)
  1918. if not is_available:
  1919. raise RuntimeError(f"Backend {self.package_name} is not available.")
  1920. return current_version
  1921. def is_satisfied(self) -> bool:
  1922. return VersionComparison.from_string(self.version_comparison).value(
  1923. version.parse(self.get_installed_version()), version.parse(self.version)
  1924. )
  1925. def __repr__(self) -> str:
  1926. return f'Backend("{self.package_name}", {VersionComparison[self.version_comparison]}, "{self.version}")'
  1927. @property
  1928. def error_message(self):
  1929. return (
  1930. f"{{0}} requires the {self.package_name} library version {self.version_comparison}{self.version}. That"
  1931. f" library was not found with this version in your environment."
  1932. )
  1933. def requires(*, backends=()):
  1934. """
  1935. This decorator enables two things:
  1936. - Attaching a `__backends` tuple to an object to see what are the necessary backends for it
  1937. to execute correctly without instantiating it
  1938. - The '@requires' string is used to dynamically import objects
  1939. """
  1940. if not isinstance(backends, (tuple, list)):
  1941. raise TypeError("Backends should be a tuple or list.")
  1942. backends = tuple(backends)
  1943. applied_backends = []
  1944. for backend in backends:
  1945. if backend in BACKENDS_MAPPING:
  1946. applied_backends.append(backend)
  1947. else:
  1948. if any(key in backend for key in ["=", "<", ">"]):
  1949. applied_backends.append(Backend(backend))
  1950. else:
  1951. raise ValueError(f"Backend should be defined in the BACKENDS_MAPPING. Offending backend: {backend}")
  1952. def inner_fn(fun):
  1953. if isinstance(fun, type):
  1954. # For classes, just attach the metadata — don't wrap, as that would
  1955. # turn the class into a plain function and break isinstance checks.
  1956. fun.__backends = applied_backends
  1957. return fun
  1958. @functools.wraps(fun)
  1959. def wrapper(*args, **kwargs):
  1960. requires_backends(fun, applied_backends)
  1961. return fun(*args, **kwargs)
  1962. wrapper.__backends = applied_backends # type: ignore [unresolved-attribute]
  1963. return wrapper
  1964. return inner_fn
  1965. BASE_FILE_REQUIREMENTS = {
  1966. lambda name, content: "modeling_" in name: ("torch",),
  1967. lambda name, content: "tokenization_" in name and name.endswith("_fast"): ("tokenizers",),
  1968. lambda name, content: "image_processing_" in name and "TorchvisionBackend" in content: (
  1969. "vision",
  1970. "torch",
  1971. "torchvision",
  1972. ),
  1973. lambda name, content: "image_processing_" in name: ("vision",),
  1974. lambda name, content: "video_processing_" in name: ("vision", "torch", "torchvision"),
  1975. }
  1976. def fetch__all__(file_content) -> list[str]:
  1977. """
  1978. Returns the content of the __all__ variable in the file content.
  1979. Returns None if not defined, otherwise returns a list of strings.
  1980. """
  1981. if "__all__" not in file_content:
  1982. return []
  1983. start_index = None
  1984. lines = file_content.splitlines()
  1985. for index, line in enumerate(lines):
  1986. if line.startswith("__all__"):
  1987. start_index = index
  1988. # There is no line starting with `__all__`
  1989. if start_index is None:
  1990. return []
  1991. lines = lines[start_index:]
  1992. if not lines[0].startswith("__all__"):
  1993. raise ValueError(
  1994. "fetch__all__ accepts a list of lines, with the first line being the __all__ variable declaration"
  1995. )
  1996. # __all__ is defined on a single line
  1997. if lines[0].endswith("]"):
  1998. return [obj.strip("\"' ") for obj in lines[0].split("=")[1].strip(" []").split(",")]
  1999. # __all__ is defined on multiple lines
  2000. else:
  2001. _all: list[str] = []
  2002. for __all__line_index in range(1, len(lines)):
  2003. if lines[__all__line_index].strip() == "]":
  2004. return _all
  2005. else:
  2006. _all.append(lines[__all__line_index].strip("\"', "))
  2007. return _all
@lru_cache
def create_import_structure_from_path(module_path):
    """
    This method takes the path to a file/a folder and returns the import structure.
    If a file is given, it will return the import structure of the parent folder.

    Import structures are designed to be digestible by `_LazyModule` objects. They are
    created from the __all__ definitions in each file as well as the `@requires` decorators
    above classes and functions.

    Only `.py` files are inspected. `__init__.py`, the `__pycache__` folder and files whose name starts with
    `convert_` or `modular_` are ignored. Sub-folders are processed recursively and appear under their folder
    name. The result is memoized with `lru_cache`, so repeated calls with the same path return the same dict
    object. NOTE(review): callers should presumably treat it as read-only.

    The import structure allows explicit display of the required backends for a given object.
    These backends are specified in two ways:
    1. Through their `@requires`, if they are exported with that decorator. This `@requires` decorator
    accepts a `backends` tuple kwarg mentioning which backends are required to run this object.

    2. If an object is defined in a file with "default" backends, it will have, at a minimum, this
    backend specified. The default backends come from the first matching predicate in
    `BASE_FILE_REQUIREMENTS`. At the time of writing:
        - file names containing `modeling_` -> `torch`
        - file names containing `tokenization_` and ending in `_fast` -> `tokenizers`
        - file names containing `image_processing_` whose source mentions `TorchvisionBackend`
          -> `vision`, `torch`, `torchvision`
        - other file names containing `image_processing_` -> `vision`
        - file names containing `video_processing_` -> `vision`, `torch`, `torchvision`

    Backends serve the purpose of displaying a clear error message to the user in case the backends are not installed.
    Should an object be imported without its required backends being in the environment, any attempt to use the
    object will raise an error mentioning which backend(s) should be added to the environment in order to use
    that object.

    Here's an example of an output import structure at the src.transformers.models level:

    {
        'albert': {
            frozenset(): {
                'configuration_albert': {'AlbertConfig'}
            },
            frozenset({'tokenizers'}): {
                'tokenization_albert_fast': {'AlbertTokenizer'}
            },
        },
        'align': {
            frozenset(): {
                'configuration_align': {'AlignConfig', 'AlignTextConfig', 'AlignVisionConfig'},
                'processing_align': {'AlignProcessor'}
            },
        },
        'altclip': {
            frozenset(): {
                'configuration_altclip': {'AltCLIPConfig', 'AltCLIPTextConfig', 'AltCLIPVisionConfig'},
                'processing_altclip': {'AltCLIPProcessor'},
            }
        }
    }
    """
    import_structure = {}

    if os.path.isfile(module_path):
        module_path = os.path.dirname(module_path)

    adjacent_modules = []

    with os.scandir(module_path) as entries:
        for entry in entries:
            if entry.name == "__pycache__":
                continue

            # Sub-packages are keyed by their folder name and processed recursively.
            if entry.is_dir():
                import_structure[entry.name] = create_import_structure_from_path(entry.path)

            # Conversion and modular source files are never exported.
            elif not entry.name.startswith(("convert_", "modular_")):
                adjacent_modules.append(entry.name)

    # We're only taking a look at files different from __init__.py
    # We could theoretically require things directly from the __init__.py
    # files, but this is not supported at this time.
    if "__init__.py" in adjacent_modules:
        adjacent_modules.remove("__init__.py")

    # Maps frozenset(backends) -> {module_name: set(object_names)} for the files of this folder.
    module_requirements = {}
    for module_name in adjacent_modules:
        # Only modules ending in `.py` are accepted here.
        if not module_name.endswith(".py"):
            continue

        with open(os.path.join(module_path, module_name), encoding="utf-8") as f:
            file_content = f.read()

        # Remove the .py suffix
        module_name = module_name[:-3]

        # Last line that was not skipped, and its index; used to find the `@requires` line
        # that precedes a class/def definition.
        previous_line = ""
        previous_index = 0

        # Some files have some requirements by default.
        # For example, any file named `modeling_xxx.py`
        # should have torch as a required backend.
        # Only the first matching predicate applies.
        base_requirements = ()
        for check, requirements in BASE_FILE_REQUIREMENTS.items():
            if check(module_name, file_content):
                base_requirements = requirements
                break

        # Objects that have a `@requires` assigned to them will get exported
        # with the backends specified in the decorator as well as the file backends.
        exported_objects = set()
        if "@requires" in file_content:
            lines = file_content.split("\n")
            for index, line in enumerate(lines):
                # This allows exporting items with other decorators. We'll take a look
                # at the line that follows at the same indentation level.
                # Indented lines, other decorators and closing parentheses are skipped entirely and
                # do not update `previous_line`.
                if line.startswith((" ", "\t", "@", ")")) and not line.startswith("@requires"):
                    continue

                # Skipping line enables putting whatever we want between the
                # requires() call and the actual class/method definition.
                # This is what enables having # Copied from statements, docs, etc.
                skip_line = False

                if "@requires" in previous_line:
                    # NOTE(review): redundant, `skip_line` was just set to False above.
                    skip_line = False

                    # Backends are defined on the same line as requires,
                    # e.g. `@requires(backends=("torch", "vision"))` -> `"torch", "vision"`.
                    if "backends" in previous_line:
                        try:
                            backends_string = previous_line.split("backends=")[1].split("(")[1].split(")")[0]
                        except IndexError:
                            raise ValueError(
                                f"Couldn't parse backends for @requires decorator in file {module_name}:{previous_line}"
                            )
                        backends = tuple(sorted([b.strip("'\",") for b in backends_string.split(", ") if b]))

                    # Backends are defined in the lines following requires, for example such as:
                    # @requires(
                    #     backends=(
                    #         "sentencepiece",
                    #         "torch",
                    #     )
                    # )
                    #
                    # or
                    #
                    # @requires(
                    #     backends=(
                    #         "sentencepiece",
                    #     )
                    # )
                    # The lines from the `@requires(` line up to the current line are scanned, and collection
                    # stops at the first line that is only `)`.
                    elif "backends" in lines[previous_index + 1]:
                        backends = []
                        for backend_line in lines[previous_index:index]:
                            if "backends" in backend_line:
                                backend_line = backend_line.split("=")[1]
                            if '"' in backend_line or "'" in backend_line:
                                if ", " in backend_line:
                                    backends.extend(backend.strip("()\"', ") for backend in backend_line.split(", "))
                                else:
                                    backends.append(backend_line.strip("()\"', "))

                            # If the line is only a ')', then we reached the end of the backends and we break.
                            if backend_line.strip() == ")":
                                break
                        backends = tuple(backends)

                    # No backends are registered for requires
                    else:
                        backends = ()

                    # The decorator backends are always combined with the file's default backends.
                    backends = frozenset(backends + base_requirements)
                    if backends not in module_requirements:
                        module_requirements[backends] = {}
                    if module_name not in module_requirements[backends]:
                        module_requirements[backends][module_name] = set()

                    # A non-definition line (e.g. a `# Copied from` comment) keeps `previous_line` pointing
                    # at the `@requires` line so the next definition still picks it up.
                    if not line.startswith("class") and not line.startswith("def"):
                        skip_line = True
                    else:
                        # len("class ") == 6, len("def ") == 4
                        start_index = 6 if line.startswith("class") else 4
                        object_name = line[start_index:].split("(")[0].strip(":")
                        module_requirements[backends][module_name].add(object_name)
                        exported_objects.add(object_name)

                if not skip_line:
                    previous_line = line
                    previous_index = index

        # All objects that are in __all__ should be exported by default.
        # These objects are exported with the file backends.
        if "__all__" in file_content:
            for _all_object in fetch__all__(file_content):
                if _all_object not in exported_objects:
                    backends = frozenset(base_requirements)
                    if backends not in module_requirements:
                        module_requirements[backends] = {}
                    if module_name not in module_requirements[backends]:
                        module_requirements[backends][module_name] = set()

                    module_requirements[backends][module_name].add(_all_object)

    # Backend (frozenset) keys of this folder's files sit next to the sub-folder (str) keys.
    import_structure = {**module_requirements, **import_structure}
    return import_structure
def spread_import_structure(nested_import_structure):
    """
    This method takes as input an unordered import structure and brings the required backends at the top-level,
    aggregating modules and objects under their required backends.

    Top-level keys that are still strings (not backend frozensets) after propagation are dropped from the result.

    Here's an example of an input import structure at the src.transformers.models level:

    {
        'albert': {
            frozenset(): {
                'configuration_albert': {'AlbertConfig'}
            },
            frozenset({'tokenizers'}): {
                'tokenization_albert_fast': {'AlbertTokenizer'}
            },
        },
        'align': {
            frozenset(): {
                'configuration_align': {'AlignConfig', 'AlignTextConfig', 'AlignVisionConfig'},
                'processing_align': {'AlignProcessor'}
            },
        },
        'altclip': {
            frozenset(): {
                'configuration_altclip': {'AltCLIPConfig', 'AltCLIPTextConfig', 'AltCLIPVisionConfig'},
                'processing_altclip': {'AltCLIPProcessor'},
            }
        }
    }

    Here's an example of an output import structure at the src.transformers.models level:

    {
        frozenset({'tokenizers'}): {
            'albert.tokenization_albert_fast': {'AlbertTokenizer'}
        },
        frozenset(): {
            'albert.configuration_albert': {'AlbertConfig'},
            'align.processing_align': {'AlignProcessor'},
            'align.configuration_align': {'AlignConfig', 'AlignTextConfig', 'AlignVisionConfig'},
            'altclip.configuration_altclip': {'AltCLIPConfig', 'AltCLIPTextConfig', 'AltCLIPVisionConfig'},
            'altclip.processing_altclip': {'AltCLIPProcessor'}
        }
    }

    """

    def propagate_frozenset(unordered_import_structure):
        # One propagation pass: returns a new dict where frozenset keys found directly below a key are swapped
        # above it. Nested dicts without frozenset keys are processed recursively.
        frozenset_first_import_structure = {}
        for _key, _value in unordered_import_structure.items():
            # If the value is not a dict (e.g. the set of object names at a leaf), no need for custom manipulation
            if not isinstance(_value, dict):
                frozenset_first_import_structure[_key] = _value

            elif any(isinstance(v, frozenset) for v in _value):
                for k, v in _value.items():
                    if isinstance(k, frozenset):
                        # Here we want to switch around _key and k to propagate k upstream if it is a frozenset
                        if k not in frozenset_first_import_structure:
                            frozenset_first_import_structure[k] = {}
                        if _key not in frozenset_first_import_structure[k]:
                            frozenset_first_import_structure[k][_key] = {}

                        frozenset_first_import_structure[k][_key].update(v)

                    else:
                        # If k is not a frozenset, it means that the dictionary is not "level": some keys (top-level)
                        # are frozensets, whereas some are not -> frozenset keys are at an unknown depth-level of the
                        # dictionary.
                        #
                        # We recursively propagate the frozenset for this specific dictionary so that the frozensets
                        # are at the top-level when we handle them.
                        propagated_frozenset = propagate_frozenset({k: v})
                        for r_k, r_v in propagated_frozenset.items():
                            if isinstance(_key, frozenset):
                                if r_k not in frozenset_first_import_structure:
                                    frozenset_first_import_structure[r_k] = {}
                                if _key not in frozenset_first_import_structure[r_k]:
                                    frozenset_first_import_structure[r_k][_key] = {}

                                # _key is a frozenset -> we switch around the r_k and _key
                                frozenset_first_import_structure[r_k][_key].update(r_v)
                            else:
                                if _key not in frozenset_first_import_structure:
                                    frozenset_first_import_structure[_key] = {}
                                if r_k not in frozenset_first_import_structure[_key]:
                                    frozenset_first_import_structure[_key][r_k] = {}

                                # _key is not a frozenset -> we keep the order of r_k and _key
                                frozenset_first_import_structure[_key][r_k].update(r_v)

            else:
                # NOTE(review): this is an assignment, not a merge, so it relies on entries for the same key not
                # having been added earlier in this pass.
                frozenset_first_import_structure[_key] = propagate_frozenset(_value)

        return frozenset_first_import_structure

    def flatten_dict(_dict, previous_key=None):
        # Join nested keys with "." (e.g. {'albert': {'configuration_albert': ...}} -> 'albert.configuration_albert').
        items = []
        for _key, _value in _dict.items():
            _key = f"{previous_key}.{_key}" if previous_key is not None else _key
            if isinstance(_value, dict):
                items.extend(flatten_dict(_value, _key).items())
            else:
                items.append((_key, _value))
        return dict(items)

    # The frozenset keys contain the necessary backends. We want these first, so we propagate them up the
    # import structure.
    ordered_import_structure = nested_import_structure

    # 6 is a number that gives us sufficient depth to go through all files and foreseeable folder depths
    # while not taking too long to parse.
    # NOTE(review): tracing `propagate_frozenset`, a frozenset key seems to rise roughly one nesting level per
    # pass. Packages nested deeper than ~6 folders would then keep string keys at the top level, and the
    # loop below silently drops those. Confirm before relying on deeper layouts.
    for i in range(6):
        ordered_import_structure = propagate_frozenset(ordered_import_structure)

    # We then flatten the dict so that it references a module path.
    # Remaining string keys (not backends) are discarded.
    flattened_import_structure = {}
    for key, value in ordered_import_structure.copy().items():
        if isinstance(key, str):
            del ordered_import_structure[key]
        else:
            flattened_import_structure[key] = flatten_dict(value)

    return flattened_import_structure
  2280. @lru_cache
  2281. def define_import_structure(module_path: str, prefix: str | None = None) -> IMPORT_STRUCTURE_T:
  2282. """
  2283. This method takes a module_path as input and creates an import structure digestible by a _LazyModule.
  2284. Here's an example of an output import structure at the src.transformers.models level:
  2285. {
  2286. frozenset({'tokenizers'}): {
  2287. 'albert.tokenization_albert_fast': {'AlbertTokenizer'}
  2288. },
  2289. frozenset(): {
  2290. 'albert.configuration_albert': {'AlbertConfig'},
  2291. 'align.processing_align': {'AlignProcessor'},
  2292. 'align.configuration_align': {'AlignConfig', 'AlignTextConfig', 'AlignVisionConfig'},
  2293. 'altclip.configuration_altclip': {'AltCLIPConfig', 'AltCLIPTextConfig', 'AltCLIPVisionConfig'},
  2294. 'altclip.processing_altclip': {'AltCLIPProcessor'}
  2295. }
  2296. }
  2297. The import structure is a dict defined with frozensets as keys, and dicts of strings to sets of objects.
  2298. If `prefix` is not None, it will add that prefix to all keys in the returned dict.
  2299. """
  2300. import_structure = create_import_structure_from_path(module_path)
  2301. spread_dict = spread_import_structure(import_structure)
  2302. if prefix is None:
  2303. return spread_dict
  2304. else:
  2305. spread_dict = {k: {f"{prefix}.{kk}": vv for kk, vv in v.items()} for k, v in spread_dict.items()}
  2306. return spread_dict
  2307. def clear_import_cache() -> None:
  2308. """
  2309. Clear cached Transformers modules to allow reloading modified code.
  2310. This is useful when actively developing/modifying Transformers code.
  2311. """
  2312. # Get all transformers modules
  2313. transformers_modules = [mod_name for mod_name in sys.modules if mod_name.startswith("transformers.")]
  2314. # Remove them from sys.modules
  2315. for mod_name in transformers_modules:
  2316. module = sys.modules[mod_name]
  2317. # Clear _LazyModule caches if applicable
  2318. if isinstance(module, _LazyModule):
  2319. module._objects = {} # Clear cached objects
  2320. del sys.modules[mod_name]
  2321. # Force reload main transformers module
  2322. if "transformers" in sys.modules:
  2323. main_module = sys.modules["transformers"]
  2324. if isinstance(main_module, _LazyModule):
  2325. main_module._objects = {} # Clear cached objects
  2326. importlib.reload(main_module)