__init__.py 105 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001
  1. """
  2. The torch package contains data structures for multi-dimensional
  3. tensors and defines mathematical operations over these tensors.
  4. Additionally, it provides many utilities for efficient serialization of
  5. Tensors and arbitrary types, and other useful utilities.
  6. It has a CUDA counterpart, that enables you to run your tensor computations
  7. on an NVIDIA GPU with compute capability >= 3.0.
  8. """
  9. # mypy: allow-untyped-defs
  10. import builtins
  11. import ctypes
  12. import functools
  13. import glob
  14. import importlib
  15. import inspect
  16. import math
  17. import os
  18. import platform
  19. import sys
  20. import textwrap
  21. import threading
  22. import warnings
  23. from collections.abc import Callable as _Callable
  24. from typing import (
  25. Any as _Any,
  26. get_origin as _get_origin,
  27. overload as _overload,
  28. TYPE_CHECKING,
  29. TypeVar as _TypeVar,
  30. )
  31. from typing_extensions import (
  32. deprecated as _deprecated,
  33. ParamSpec as _ParamSpec,
  34. TypeIs as _TypeIs,
  35. )
  36. # As a bunch of torch.packages internally still have this check
  37. # we need to keep this. @todo: Remove tests that rely on this check as
  38. # they are likely stale.
  39. def _running_with_deploy() -> builtins.bool:
  40. return False
  41. from torch._utils import (
  42. _functionalize_sync as _sync,
  43. _import_dotted_name,
  44. classproperty,
  45. )
  46. from torch._utils_internal import (
  47. get_file_path,
  48. prepare_multiprocessing_environment,
  49. profiler_allow_cudagraph_cupti_lazy_reinit_cuda12,
  50. USE_GLOBAL_DEPS,
  51. USE_RTLD_GLOBAL_WITH_LIBTORCH,
  52. )
  53. from torch.torch_version import __version__ as __version__
  54. if TYPE_CHECKING:
  55. from torch.types import Device, IntLikeType
# Names exported by ``from torch import *``. The entries are plain strings, so
# the list can be declared before the names themselves are defined.
__all__ = [
    "BoolStorage",
    "BoolTensor",
    "ByteStorage",
    "ByteTensor",
    "CharStorage",
    "CharTensor",
    "DoubleStorage",
    "DoubleTensor",
    "FloatStorage",
    "FloatTensor",
    "GradScaler",
    "IntStorage",
    "IntTensor",
    "LongStorage",
    "LongTensor",
    "ShortStorage",
    "ShortTensor",
    "SymBool",
    "SymFloat",
    "SymInt",
    "Tensor",
    "TypedStorage",
    "UntypedStorage",
    "are_deterministic_algorithms_enabled",
    "autocast",
    "chunk",
    "compile",
    "cond",
    "enable_grad",
    "export",
    "get_default_device",
    "get_deterministic_debug_mode",
    "get_device_module",
    "get_float32_matmul_precision",
    "get_rng_state",
    "inference_mode",
    "initial_seed",
    "is_deterministic_algorithms_warn_only_enabled",
    "is_storage",
    "is_tensor",
    "is_warn_always_enabled",
    "load",
    "lobpcg",
    "manual_seed",
    "matmul",
    "no_grad",
    "rand",
    "randn",
    "save",
    "seed",
    "set_default_device",
    "set_default_tensor_type",
    "set_deterministic_debug_mode",
    "set_float32_matmul_precision",
    "set_printoptions",
    "set_rng_state",
    "set_warn_always",
    "split",
    "stack",
    "sym_float",
    "sym_fresh_size",
    "sym_int",
    "sym_ite",
    "sym_max",
    "sym_min",
    "sym_not",
    "sym_sum",
    "typename",
    "unravel_index",
    "use_deterministic_algorithms",
    "vmap",
]

# Please keep this list sorted
# (enforced at import time: an out-of-order entry makes importing this module fail)
if __all__ != sorted(__all__):
    raise AssertionError("__all__ must be kept sorted")
################################################################################
# Load the extension module
################################################################################

# If PyTorch was built against the ROCm runtime wheels, then there will be
# a _rocm_init module and it will define an initialize() function which can
# prepare ROCm for use. See general documentation on ROCm runtime wheels:
#   https://github.com/ROCm/TheRock/blob/main/docs/packaging/python_packaging.md
# Since this module is only ever added to the wheel if built for such a
# deployment, it is always safe to attempt.
try:
    from . import _rocm_init  # type: ignore[attr-defined]
except ImportError:
    # The helper module could not be imported: nothing to initialize.
    pass
else:
    # Only reached when the import succeeded.
    _rocm_init.initialize()
    # Remove the helper from this module's namespace once it has run.
    del _rocm_init
  148. if sys.platform == "win32":
  149. def _load_dll_libraries() -> None:
  150. import sysconfig
  151. from torch.version import cuda as cuda_version
  152. pfiles_path = os.getenv("ProgramFiles", r"C:\Program Files")
  153. py_dll_path = os.path.join(sys.exec_prefix, "Library", "bin")
  154. th_dll_path = os.path.join(os.path.dirname(__file__), "lib")
  155. usebase_path = os.path.join(
  156. sysconfig.get_config_var("userbase"), "Library", "bin"
  157. )
  158. py_root_bin_path = os.path.join(sys.exec_prefix, "bin")
  159. # When users create a virtualenv that inherits the base environment,
  160. # we will need to add the corresponding library directory into
  161. # DLL search directories. Otherwise, it will rely on `PATH` which
  162. # is dependent on user settings.
  163. if sys.exec_prefix != sys.base_exec_prefix:
  164. base_py_dll_path = os.path.join(sys.base_exec_prefix, "Library", "bin")
  165. else:
  166. base_py_dll_path = ""
  167. dll_paths = [
  168. p
  169. for p in (
  170. th_dll_path,
  171. py_dll_path,
  172. base_py_dll_path,
  173. usebase_path,
  174. py_root_bin_path,
  175. )
  176. if os.path.exists(p)
  177. ]
  178. if not builtins.any(
  179. os.path.exists(os.path.join(p, "nvToolsExt64_1.dll")) for p in dll_paths
  180. ):
  181. nvtoolsext_dll_path = os.path.join(
  182. os.getenv(
  183. "NVTOOLSEXT_PATH",
  184. os.path.join(pfiles_path, "NVIDIA Corporation", "NvToolsExt"),
  185. ),
  186. "bin",
  187. "x64",
  188. )
  189. else:
  190. nvtoolsext_dll_path = ""
  191. if cuda_version and builtins.all(
  192. not glob.glob(os.path.join(p, "cudart64*.dll")) for p in dll_paths
  193. ):
  194. cuda_version_1 = cuda_version.replace(".", "_")
  195. cuda_path_var = "CUDA_PATH_V" + cuda_version_1
  196. default_path = os.path.join(
  197. pfiles_path, "NVIDIA GPU Computing Toolkit", "CUDA", f"v{cuda_version}"
  198. )
  199. cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), "bin")
  200. else:
  201. cuda_path = ""
  202. dll_paths.extend(
  203. p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p)
  204. )
  205. kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
  206. with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
  207. prev_error_mode = kernel32.SetErrorMode(0x0001)
  208. kernel32.LoadLibraryW.restype = ctypes.c_void_p
  209. if with_load_library_flags:
  210. kernel32.LoadLibraryExW.restype = ctypes.c_void_p
  211. for dll_path in dll_paths:
  212. os.add_dll_directory(dll_path)
  213. try:
  214. ctypes.CDLL("vcruntime140.dll")
  215. ctypes.CDLL("msvcp140.dll")
  216. if platform.machine() != "ARM64":
  217. ctypes.CDLL("vcruntime140_1.dll")
  218. except OSError:
  219. print(
  220. textwrap.dedent(
  221. """
  222. Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.
  223. It can be downloaded at https://aka.ms/vs/17/release/vc_redist.x64.exe
  224. """
  225. ).strip()
  226. )
  227. dlls = glob.glob(os.path.join(th_dll_path, "*.dll"))
  228. path_patched = False
  229. for dll in dlls:
  230. is_loaded = False
  231. if with_load_library_flags:
  232. res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
  233. last_error = ctypes.get_last_error()
  234. if res is None and last_error != 126:
  235. err = ctypes.WinError(last_error)
  236. err.strerror += (
  237. f' Error loading "{dll}" or one of its dependencies.'
  238. )
  239. raise err
  240. elif res is not None:
  241. is_loaded = True
  242. if not is_loaded:
  243. if not path_patched:
  244. os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]])
  245. path_patched = True
  246. res = kernel32.LoadLibraryW(dll)
  247. if res is None:
  248. err = ctypes.WinError(ctypes.get_last_error())
  249. err.strerror += (
  250. f' Error loading "{dll}" or one of its dependencies.'
  251. )
  252. raise err
  253. kernel32.SetErrorMode(prev_error_mode)
  254. _load_dll_libraries()
  255. del _load_dll_libraries
  256. def _get_cuda_dep_paths(path: str, lib_folder: str, lib_name: str) -> list[str]:
  257. # Libraries can either be in
  258. # path/nvidia/lib_folder/lib or
  259. # path/nvidia/cuXX/lib (since CUDA 13.0) or
  260. # path/lib_folder/lib
  261. from torch.version import cuda as cuda_version
  262. nvidia_lib_paths = glob.glob(
  263. os.path.join(path, "nvidia", lib_folder, "lib", lib_name)
  264. )
  265. if cuda_version is not None:
  266. maj_cuda_version = cuda_version.split(".")[0]
  267. nvidia_lib_paths += glob.glob(
  268. os.path.join(path, "nvidia", f"cu{maj_cuda_version}", "lib", lib_name)
  269. )
  270. lib_paths = glob.glob(os.path.join(path, lib_folder, "lib", lib_name))
  271. return nvidia_lib_paths + lib_paths
  272. def _preload_cuda_lib(lib_folder: str, lib_name: str, required: bool = True) -> None: # type: ignore[valid-type]
  273. """Preloads cuda library if it could not be found otherwise."""
  274. # Should only be called on Linux if default path resolution have failed
  275. if platform.system() != "Linux":
  276. raise AssertionError(f"Should only be called on Linux, got {platform.system()}")
  277. lib_path = None
  278. for path in sys.path:
  279. candidate_lib_paths = _get_cuda_dep_paths(path, lib_folder, lib_name)
  280. if candidate_lib_paths:
  281. lib_path = candidate_lib_paths[0]
  282. break
  283. if not lib_path and required:
  284. raise ValueError(f"{lib_name} not found in the system path {sys.path}")
  285. if lib_path:
  286. ctypes.CDLL(lib_path)
  287. def _preload_cuda_deps(err: OSError | None = None) -> None:
  288. cuda_libs: list[tuple[str, str]] = [
  289. # NOTE: Order matters! We must preload libcublasLt BEFORE libcublas to prevent
  290. # libcublas from loading a mismatched system-wide libcublasLt via its RUNPATH.
  291. # Without this, if a different CUDA Toolkit version exists in the system PATH,
  292. # libcublas may load the wrong libcublasLt, causing symbol errors or runtime failures.
  293. ("cublas", "libcublasLt.so.*[0-9]"),
  294. ("cublas", "libcublas.so.*[0-9]"),
  295. ("cudnn", "libcudnn.so.*[0-9]"),
  296. ("cuda_nvrtc", "libnvrtc.so.*[0-9]"),
  297. ("cuda_nvrtc", "libnvrtc-builtins.so.*[0-9]"),
  298. ("cuda_runtime", "libcudart.so.*[0-9]"),
  299. ("cuda_cupti", "libcupti.so.*[0-9]"),
  300. ("cufft", "libcufft.so.*[0-9]"),
  301. ("curand", "libcurand.so.*[0-9]"),
  302. ("nvjitlink", "libnvJitLink.so.*[0-9]"),
  303. ("cusparse", "libcusparse.so.*[0-9]"),
  304. ("cusparselt", "libcusparseLt.so.*[0-9]"),
  305. ("cusolver", "libcusolver.so.*[0-9]"),
  306. ("nccl", "libnccl.so.*[0-9]"),
  307. ("nvshmem", "libnvshmem_host.so.*[0-9]"),
  308. ("cufile", "libcufile.so.*[0-9]"),
  309. ]
  310. # If error is passed, re-raise it if it's not about one of the abovementioned
  311. # libraries
  312. if err is not None and not [
  313. lib for _, lib in cuda_libs if lib.split(".", 1)[0] in err.args[0]
  314. ]:
  315. raise err
  316. # Otherwise, try to preload dependencies from site-packages
  317. for lib_folder, lib_name in cuda_libs:
  318. _preload_cuda_lib(lib_folder, lib_name)
  319. # libnvToolsExt is Optional Dependency
  320. _preload_cuda_lib("nvtx", "libnvToolsExt.so.*[0-9]", required=False)
# See Note [Global dependencies]
def _load_global_deps() -> None:
    """Load ``libtorch_global_deps`` with ``RTLD_GLOBAL``, preloading CUDA libraries if needed.

    Does nothing on Windows. Otherwise the library is loaded from the ``lib``
    directory next to this file. If that load succeeds and ``libcudart.so``
    shows up in ``/proc/self/maps``, the CUDA dependencies are preloaded on a
    best-effort basis. If the load fails with ``OSError``, the CUDA
    dependencies are preloaded first and the load is retried once.

    Raises:
        OSError: re-raised by ``_preload_cuda_deps`` when the initial failure
            is unrelated to the CUDA libraries, or raised by the retry.
    """
    if platform.system() == "Windows":
        return

    # Determine the file extension based on the platform
    lib_ext = ".dylib" if platform.system() == "Darwin" else ".so"
    lib_name = f"libtorch_global_deps{lib_ext}"
    here = os.path.abspath(__file__)
    global_deps_lib_path = os.path.join(os.path.dirname(here), "lib", lib_name)

    try:
        ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
        # Workaround slim-wheel CUDA dependency bugs in cusparse and cudnn by preloading nvjitlink
        # and nvrtc. In CUDA-12.4+ cusparse depends on nvjitlink, but does not have rpath when
        # shipped as wheel, which results in OS picking wrong/older version of nvjitlink library
        # if `LD_LIBRARY_PATH` is defined, see https://github.com/pytorch/pytorch/issues/138460
        # Similar issue exist in cudnn that dynamically loads nvrtc, unaware of its relative path.
        # See https://github.com/pytorch/pytorch/issues/145580
        try:
            with open("/proc/self/maps") as f:
                _maps = f.read()
            # libtorch_global_deps.so always depends on cudart; check whether it is
            # installed and loaded
            if "libcudart.so" not in _maps:
                return
            # If all above-mentioned conditions are met, preload CUDA dependencies
            _preload_cuda_deps()
        except Exception:
            # Best effort only: any failure here (e.g. reading /proc/self/maps or
            # locating a library) is ignored because the main library has
            # already been loaded successfully.
            pass

    except OSError as err:
        # Can happen for wheel with cuda libs as PYPI deps
        # As PyTorch is not purelib, but nvidia-*-cu12 is
        _preload_cuda_deps(err)
        # Retry now that the CUDA dependencies have been preloaded.
        ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
# Import the C extension (torch._C). The branch taken decides which dlopen
# flags are in effect while the extension is imported.
if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and (
    platform.system() != "Windows"
):
    # Do it the hard way.  You might want to load libtorch with RTLD_GLOBAL in a
    # few circumstances:
    #
    #   1. You're in a build environment (e.g., fbcode) where
    #      libtorch_global_deps is not available, but you still need
    #      to get mkl to link in with RTLD_GLOBAL or it will just
    #      not work.
    #
    #   2. You're trying to run PyTorch under UBSAN and you need
    #      to ensure that only one copy of libtorch is loaded, so
    #      vptr checks work properly
    #
    # If you're using this setting, you must verify that all the libraries
    # you load consistently use the same libstdc++, or you may have
    # mysterious segfaults.
    #
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY)

    from torch._C import *  # noqa: F403

    # Restore the interpreter's previous dlopen flags once the import is done.
    sys.setdlopenflags(old_flags)
    del old_flags

else:
    # Easy way.  You want this most of the time, because it will prevent
    # C++ symbols from libtorch clobbering C++ symbols from other
    # libraries, leading to mysterious segfaults.
    #
    # If building in an environment where libtorch_global_deps isn't available
    # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will
    # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False
    #
    # See Note [Global dependencies]
    if USE_GLOBAL_DEPS:
        _load_global_deps()
    from torch._C import *  # noqa: F403
  390. class SymInt:
  391. """
  392. Like an int (including magic methods), but redirects all operations on the
  393. wrapped node. This is used in particular to symbolically record operations
  394. in the symbolic shape workflow.
  395. """
  396. def __init__(self, node):
  397. # This field MUST be named node; C++ binding code assumes that this
  398. # class has a field named node that stores SymNode
  399. self.node = node
  400. def __bool__(self):
  401. return builtins.bool(self != 0)
  402. def __int__(self):
  403. return self.node.int_()
  404. def __index__(self):
  405. return self.node.int_()
  406. # Magic methods installed by torch.fx.experimental.sym_node
  407. def __round__(self, ndigits=None):
  408. return self
  409. def __truediv__(self, other):
  410. if isinstance(other, (builtins.float, SymFloat)):
  411. return sym_float(self).__float_truediv__(other)
  412. if not isinstance(other, (builtins.int, SymInt)):
  413. return NotImplemented
  414. return self.__int_truediv__(other)
  415. def __rtruediv__(self, other):
  416. if isinstance(other, (builtins.float, SymFloat)):
  417. return sym_float(self).__rfloat_truediv__(other)
  418. if not isinstance(other, (builtins.int, SymInt)):
  419. return NotImplemented
  420. return self.__rint_truediv__(other)
  421. def __floordiv__(self, other):
  422. if isinstance(other, (builtins.float, SymFloat)):
  423. return sym_float(math.floor(sym_float(self) / other))
  424. if not isinstance(other, (builtins.int, SymInt)):
  425. return NotImplemented
  426. return self.__int_floordiv__(other)
  427. def __rfloordiv__(self, other):
  428. if isinstance(other, (builtins.float, SymFloat)):
  429. return sym_float(math.floor(other / sym_float(self)))
  430. if not isinstance(other, (builtins.int, SymInt)):
  431. return NotImplemented
  432. return self.__rint_floordiv__(other)
  433. # nb: complex is impossible to handle correctly lol, with
  434. # negative base and integral float need to diverge semantics and
  435. # just always return complex. Neener neener pretend this problem
  436. # doesn't exist
  437. def __pow__(self, other):
  438. if isinstance(other, (builtins.float, SymFloat)):
  439. return sym_float(self).__pow__(other)
  440. if not isinstance(other, (builtins.int, SymInt)):
  441. return NotImplemented
  442. # Guards! This guard is necessary because we need to know it to
  443. # determine the output type of this operation
  444. if other >= 0:
  445. return self.__pow_by_natural__(other)
  446. else:
  447. # Mercifully, when the exponent is negative, Python just promotes
  448. # to doubles and does a float pow:
  449. #
  450. # if (Py_SIZE(b) < 0 && c == NULL) {
  451. # /* if exponent is negative and there's no modulus:
  452. # return a float. This works because we know
  453. # that this calls float_pow() which converts its
  454. # arguments to double. */
  455. # Py_DECREF(a);
  456. # Py_DECREF(b);
  457. # return PyFloat_Type.tp_as_number->nb_power(v, w, x);
  458. # }
  459. return sym_float(self).__pow__(sym_float(other))
  460. def __rpow__(self, other):
  461. if isinstance(other, (builtins.float, SymFloat)):
  462. return sym_float(self).__rpow__(other)
  463. if not isinstance(other, (builtins.int, SymInt)):
  464. return NotImplemented
  465. if self >= 0: # self is exponent
  466. return self.__rpow_by_natural__(other)
  467. else:
  468. return sym_float(self).__rpow__(sym_float(other))
  469. def __eq__(self, other: object) -> builtins.bool:
  470. raise TypeError("type stub not overridden")
  471. def __lt__(self, other) -> builtins.bool:
  472. raise TypeError("type stub not overridden")
  473. def __gt__(self, other) -> builtins.bool:
  474. raise TypeError("type stub not overridden")
  475. def __le__(self, other) -> builtins.bool:
  476. raise TypeError("type stub not overridden")
  477. def __ge__(self, other) -> builtins.bool:
  478. raise TypeError("type stub not overridden")
  479. def __add__(self, other) -> "SymInt":
  480. raise TypeError("type stub not overridden")
  481. def __radd__(self, other) -> "SymInt":
  482. raise TypeError("type stub not overridden")
  483. def __rmul__(self, other) -> "SymInt":
  484. raise TypeError("type stub not overridden")
  485. def __mod__(self, other: "IntLikeType") -> "SymInt":
  486. raise TypeError("type stub not overridden")
  487. def __mul__(self, other) -> "SymInt":
  488. raise TypeError("type stub not overridden")
  489. def __pow_by_natural__(self, other) -> "SymInt":
  490. raise TypeError("type stub not overridden")
  491. def __rpow_by_natural__(self, other) -> "SymInt":
  492. raise TypeError("type stub not overridden")
  493. def __int_truediv__(self, other) -> "SymFloat":
  494. raise TypeError("type stub not overridden")
  495. def __rint_truediv__(self, other) -> "SymFloat":
  496. raise TypeError("type stub not overridden")
  497. def __int_floordiv__(self, other) -> "SymFloat":
  498. raise TypeError("type stub not overridden")
  499. def __rint_floordiv__(self, other) -> "SymFloat":
  500. raise TypeError("type stub not overridden")
  501. def __sym_max__(self, other):
  502. raise TypeError("type stub not overridden")
  503. def __sym_min__(self, other):
  504. raise TypeError("type stub not overridden")
  505. def __sym_float__(self):
  506. raise TypeError("type stub not overridden")
  507. def __neg__(self):
  508. raise TypeError("type stub not overridden")
  509. def __sub__(self, other: "IntLikeType") -> "SymInt":
  510. raise TypeError("type stub not overridden")
  511. def __rsub__(self, other: "IntLikeType") -> "SymInt":
  512. raise TypeError("type stub not overridden")
  513. def __and__(self, other) -> "SymInt":
  514. raise TypeError("type stub not overridden")
  515. def __or__(self, other) -> "SymInt":
  516. raise TypeError("type stub not overridden")
  517. def __repr__(self):
  518. return self.node._graph_repr()
  519. def _sympy_(self):
  520. return self.node.expr
  521. def __hash__(self) -> builtins.int:
  522. if self.node.is_nested_int():
  523. return hash(self.node.nested_int())
  524. else:
  525. # We could support constant SymInts as well, but not doing it for now
  526. raise TypeError("unhashable type: non-nested SymInt")
  527. # TODO: Force specialization
  528. # This can't be done because the TypeError here is load bearing
  529. # for einops
  530. # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
  531. # return hash(builtins.int(self))
    def as_integer_ratio(self) -> tuple["SymInt", builtins.int]:
        """Represent this int as an exact integer ratio.

        Returns:
            The pair ``(self, 1)``: the value itself over a denominator of one.
        """
        return self, 1
  535. def bit_length(self) -> builtins.int:
  536. # TODO: A more relaxed guard is possible here, where you guard to
  537. # allow all integer quantities which would result in the same bit
  538. # length. We can also just make a dedicated Sympy function for
  539. # computing this quantity and represent it symbolically.
  540. return builtins.int(self).bit_length()
    def conjugate(self) -> "SymInt":
        """Return ``self`` unchanged."""
        return self
  543. class SymFloat:
  544. """
  545. Like a float (including magic methods), but redirects all operations on the
  546. wrapped node. This is used in particular to symbolically record operations
  547. in the symbolic shape workflow.
  548. """
  549. def __init__(self, node):
  550. # This field MUST be named node; C++ binding code assumes that this
  551. # class has a field named node that stores SymNode
  552. self.node = node
  553. def __truediv__(self, other):
  554. if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
  555. return NotImplemented
  556. return self.__float_truediv__(sym_float(other))
  557. def __rtruediv__(self, other):
  558. if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
  559. return NotImplemented
  560. return self.__rfloat_truediv__(sym_float(other))
  561. def __floordiv__(self, other):
  562. if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
  563. return NotImplemented
  564. return sym_float(math.floor(self / sym_float(other)))
  565. def __rfloordiv__(self, other):
  566. if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
  567. return NotImplemented
  568. return sym_float(math.floor(sym_float(other) / self))
  569. def __bool__(self):
  570. return self.node.bool_()
  571. def __float__(self):
  572. return self.node.guard_float("", 0)
  573. def __int__(self):
  574. return self.__trunc__().__int__()
  575. # Symbolic power does NOT work with negative base, this is to avoid
  576. # potential complex outputs
  577. def __pow__(self, other):
  578. if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
  579. return NotImplemented
  580. torch._check(self >= 0)
  581. return self.__float_pow__(other)
  582. def __rpow__(self, other):
  583. if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
  584. return NotImplemented
  585. torch._check(other >= 0)
  586. return self.__rfloat_pow__(other)
  587. # Magic methods installed by torch.fx.experimental.sym_node
  588. def __eq__(self, other: object) -> builtins.bool:
  589. raise TypeError("type stub not overridden")
  590. def __lt__(self, other) -> builtins.bool:
  591. raise TypeError("type stub not overridden")
  592. def __gt__(self, other) -> builtins.bool:
  593. raise TypeError("type stub not overridden")
  594. def __le__(self, other) -> builtins.bool:
  595. raise TypeError("type stub not overridden")
  596. def __ge__(self, other) -> builtins.bool:
  597. raise TypeError("type stub not overridden")
  598. def __float_pow__(self, other) -> "SymFloat":
  599. raise TypeError("type stub not overridden")
  600. def __rfloat_pow__(self, other) -> "SymFloat":
  601. raise TypeError("type stub not overridden")
  602. def __float_truediv__(self, other) -> "SymFloat":
  603. raise TypeError("type stub not overridden")
  604. def __rfloat_truediv__(self, other) -> "SymFloat":
  605. raise TypeError("type stub not overridden")
  606. def __trunc__(self):
  607. raise TypeError("type stub not overridden")
  608. def __sym_max__(self, other):
  609. raise TypeError("type stub not overridden")
  610. def __sym_min__(self, other):
  611. raise TypeError("type stub not overridden")
  612. def __sym_int__(self):
  613. raise TypeError("type stub not overridden")
  614. def is_integer(self):
  615. """Return True if the float is an integer."""
  616. raise TypeError("type stub not overridden")
  617. def as_integer_ratio(self) -> tuple[builtins.int, builtins.int]:
  618. """Represent this float as an exact integer ratio"""
  619. return builtins.float(self).as_integer_ratio()
  620. def __repr__(self):
  621. return self.node._graph_repr()
  622. def _sympy_(self):
  623. return self.node.expr
  624. def __hash__(self):
  625. return hash(builtins.float(self))
  626. def conjugate(self) -> "SymFloat":
  627. """Returns the complex conjugate of the float."""
  628. return self
  629. def hex(self) -> str:
  630. """Returns the hexadecimal representation of the float."""
  631. return self.node.guard_float("", 0).hex()
  632. class SymBool:
  633. """
  634. Like a bool (including magic methods), but redirects all operations on the
  635. wrapped node. This is used in particular to symbolically record operations
  636. in the symbolic shape workflow.
  637. Unlike regular bools, regular boolean operators will force extra guards instead
  638. of symbolically evaluate. Use the bitwise operators instead to handle this.
  639. """
  640. def __init__(self, node):
  641. # This field MUST be named node; C++ binding code assumes that this
  642. # class has a field named node that stores SymNode
  643. self.node = node
  644. def __bool__(self):
  645. return self.node.bool_()
  646. def __int__(self):
  647. return builtins.int(self.node.bool_())
  648. # Magic methods installed by torch.fx.experimental.sym_node
  649. def __and__(self, other) -> "SymBool":
  650. raise TypeError("type stub not overridden")
  651. def __or__(self, other) -> "SymBool":
  652. raise TypeError("type stub not overridden")
  653. # We very carefully define __sym_not__, and not a number of other
  654. # plausible alternatives:
  655. #
  656. # - We do not override __not__ because this is not a real magic
  657. # method; you cannot override the meaning of the not builtin in
  658. # Python. We use the name 'sym_not' to clarify that in user code you
  659. # cannot use the builtin not or operator.not_ or operator.__not__ and
  660. # hit this magic method; you must use our custom sym_not operator.
  661. #
  662. # - We do not override the __invert__ method because SymBool is
  663. # meant to be usable in situations where bool is expected. However,
  664. # bitwise negation ~a does the wrong thing with booleans (because
  665. # bool is a subclass of int, so ~1 = -2 which is not falseish.)
  666. # This would be a giant footgun, so we get around it by defining
  667. # our own operator. Note that bitwise and/or do the right thing,
  668. # so we reuse the conventional operators there for readability.
  669. #
  670. def __sym_not__(self) -> "SymBool":
  671. raise TypeError("type stub not overridden")
  672. def __sym_ite__(self, then_val, else_val):
  673. raise TypeError("type stub not overridden")
  674. def __eq__(self, other) -> builtins.bool:
  675. raise TypeError("type stub not overridden")
  676. def __repr__(self):
  677. return self.node._graph_repr()
  678. def _sympy_(self):
  679. return self.node.expr
  680. def __hash__(self):
  681. if self.node.is_constant():
  682. return hash(self.node.bool_())
  683. else:
  684. # Force specialization
  685. return hash(builtins.bool(self))
  686. def __sym_float__(self):
  687. """
  688. Provides a SymFloat representation (0.0 or 1.0) for this SymBool.
  689. Called by torch.sym_float() when casting SymBool to float.
  690. """
  691. from torch.fx.experimental.sym_node import wrap_node
  692. return wrap_node(self.node.sym_float())
  693. def sym_not(a):
  694. r"""SymInt-aware utility for logical negation.
  695. Args:
  696. a (SymBool or bool): Object to negate
  697. """
  698. import sympy
  699. if overrides.has_torch_function_unary(a):
  700. return overrides.handle_torch_function(sym_not, (a,), a)
  701. if hasattr(a, "__sym_not__"):
  702. return a.__sym_not__()
  703. if isinstance(a, sympy.Basic):
  704. return ~a # type: ignore[operator]
  705. return not a
  706. def sym_float(a):
  707. r"""SymInt-aware utility for float casting.
  708. Args:
  709. a (SymInt, SymFloat, or object): Object to cast
  710. """
  711. if overrides.has_torch_function_unary(a):
  712. return overrides.handle_torch_function(sym_float, (a,), a)
  713. if isinstance(a, SymFloat):
  714. return a
  715. elif hasattr(a, "__sym_float__"):
  716. return a.__sym_float__()
  717. return builtins.float(a) # type: ignore[operator]
  718. def sym_int(a):
  719. r"""SymInt-aware utility for int casting.
  720. Args:
  721. a (SymInt, SymFloat, or object): Object to cast
  722. """
  723. if overrides.has_torch_function_unary(a):
  724. return overrides.handle_torch_function(sym_int, (a,), a)
  725. if isinstance(a, SymInt):
  726. return a
  727. elif isinstance(a, SymFloat):
  728. return math.trunc(a)
  729. return builtins.int(a) # type: ignore[operator]
  730. def sym_max(a, b):
  731. """
  732. SymInt-aware utility for max which avoids branching on a < b.
  733. Unlike builtins.max(), this only works for int/float, and it always
  734. promotes to float if any argument is float (unlike builtins.max, which
  735. will faithfully preserve the type of the input argument).
  736. """
  737. if overrides.has_torch_function((a, b)):
  738. return overrides.handle_torch_function(sym_max, (a, b), a, b)
  739. if isinstance(a, (SymInt, SymFloat)):
  740. return a.__sym_max__(b)
  741. elif isinstance(b, (SymInt, SymFloat)):
  742. # Due to promotion semantics, this is operator is commutative:
  743. # max(1, 1.0) === max(1.0, 1) === 1.0
  744. return b.__sym_max__(a)
  745. # TODO: Probably can make bool work too, just lazy
  746. all_types, float_types = __all_and_float_types()
  747. if not isinstance(a, all_types):
  748. raise AssertionError(f"expected {all_types}, got {type(a)}")
  749. if not isinstance(b, all_types):
  750. raise AssertionError(f"expected {all_types}, got {type(b)}")
  751. if isinstance(a, float_types) or isinstance(b, float_types):
  752. return builtins.float(builtins.max(a, b)) # type: ignore[call-overload]
  753. else:
  754. return builtins.max(a, b) # type: ignore[call-overload]
  755. def __all_and_float_types() -> tuple[tuple[type, ...], tuple[type, ...]]:
  756. try:
  757. import numpy as np
  758. all_types: tuple[type, ...] = (
  759. np.integer,
  760. np.floating,
  761. builtins.int,
  762. builtins.float,
  763. )
  764. float_types: tuple[type, ...] = (np.floating, builtins.float)
  765. except ModuleNotFoundError:
  766. all_types = (builtins.int, builtins.float)
  767. float_types = (builtins.float,)
  768. return all_types, float_types
  769. def sym_min(a, b):
  770. """SymInt-aware utility for min()."""
  771. if overrides.has_torch_function((a, b)):
  772. return overrides.handle_torch_function(sym_min, (a, b), a, b)
  773. if isinstance(a, (SymInt, SymFloat)):
  774. return a.__sym_min__(b)
  775. elif isinstance(b, (SymInt, SymFloat)):
  776. return b.__sym_min__(a)
  777. all_types, float_types = __all_and_float_types()
  778. if not isinstance(a, all_types):
  779. raise AssertionError(f"expected {all_types}, got {type(a)}")
  780. if not isinstance(b, all_types):
  781. raise AssertionError(f"expected {all_types}, got {type(b)}")
  782. if isinstance(a, float_types) or isinstance(b, float_types):
  783. return builtins.float(builtins.min(a, b)) # type: ignore[call-overload]
  784. else:
  785. return builtins.min(a, b) # type: ignore[call-overload]
  786. def sym_sum(args):
  787. """
  788. N-ary add which is faster to compute for long lists than iterated binary
  789. addition. Only does something special for integers.
  790. """
  791. if overrides.has_torch_function(args):
  792. return overrides.handle_torch_function(sym_sum, args, args)
  793. found = None
  794. for a in args:
  795. if not isinstance(a, (SymInt, builtins.int)):
  796. return builtins.sum(args)
  797. if isinstance(a, SymInt):
  798. found = a.node
  799. if found is None:
  800. return builtins.sum(args)
  801. from torch.fx.experimental.sym_node import to_node, wrap_node
  802. return wrap_node(found.sym_sum(tuple(to_node(found, a) for a in args)))
  803. # Drop in replacement for math.sqrt, math.sin, math.cos etc
  804. def _get_sym_math_fn(name):
  805. def fn(a):
  806. if overrides.has_torch_function_unary(a):
  807. return overrides.handle_torch_function(fn, (a,), a)
  808. if isinstance(a, SymInt):
  809. a = torch.sym_float(a)
  810. if hasattr(a, f"__sym_{name}__"):
  811. return getattr(a, f"__sym_{name}__")()
  812. return getattr(math, name)(a)
  813. return fn
# Register a private `_sym_<name>` wrapper in this module's globals for each
# math function listed below. The three temporaries are pre-bound so that the
# `del` at the end works even though they are plain module-level names.
__fn, __name, __sym_name = None, "", ""
for __name in (
    "sqrt",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "asin",
    "acos",
    "atan",
    "log2",
):
    __sym_name = f"_sym_{__name}"
    __fn = _get_sym_math_fn(__name)
    # Rename the wrapper so it reports its registered name, not the
    # factory's inner function name.
    __fn.__qualname__ = __fn.__name__ = __sym_name
    globals()[__sym_name] = __fn

# Remove the loop temporaries and the factory from the module namespace.
del __fn, __name, __sym_name, _get_sym_math_fn
# Adding temporary shortcut: expose the `_sym_sqrt` wrapper registered above
# under the public name `sym_sqrt` and list it in `__all__`.
sym_sqrt = globals()["_sym_sqrt"]
__all__.append("sym_sqrt")
  836. def sym_ite(b, t, f):
  837. """SymInt-aware utility for ternary operator (``t if b else f``.)"""
  838. if overrides.has_torch_function((b, t, f)):
  839. return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f)
  840. if not isinstance(b, (SymBool, builtins.bool)):
  841. raise AssertionError(f"expected SymBool or bool, got {type(b)}")
  842. if type(t) is not type(f):
  843. raise AssertionError(f"type mismatch: {type(t)} vs {type(f)}")
  844. if isinstance(b, SymBool):
  845. return b.__sym_ite__(t, f)
  846. return t if b else f
  847. # Create a fresh unbacked int, from an (possibly unbacked int) expression.
  848. def sym_fresh_size(expr):
  849. return torch.tensor(expr).item()
# Check to see if we can load C extensions, and if not provide some guidance
# on what the problem might be.
try:
    # _initExtension is chosen (arbitrarily) as a sentinel.
    from torch._C import _initExtension
except ImportError:
    import torch._C as _C_for_compiled_check

    # A `torch._C` module without a `__file__` is taken to be the repository's
    # `torch/_C` folder and not the compiled extension (see the message
    # below), so raise a dedicated, actionable error for that case.
    if _C_for_compiled_check.__file__ is None:
        raise ImportError(
            textwrap.dedent(
                """
                Failed to load PyTorch C extensions:
                It appears that PyTorch has loaded the `torch/_C` folder
                of the PyTorch repository rather than the C extensions which
                are expected in the `torch._C` namespace. This can occur when
                using the `install` workflow. e.g.
                $ python -m pip install --no-build-isolation -v . && python -c "import torch"
                This error can generally be solved using the `develop` workflow
                $ python -m pip install --no-build-isolation -v -e . && python -c "import torch" # This should succeed
                or by running Python from a different directory.
                """
            ).strip()
        ) from None  # `from None` hides the original ImportError from the traceback
    raise  # If __file__ is not None the cause is unknown, so just re-raise.
  874. # The torch._C submodule is already loaded via `from torch._C import *` above
  875. # Make an explicit reference to the _C submodule to appease linters
  876. from torch import _C as _C
# Add the public names of `_C` to `__all__` and make their callables/classes
# report this module as their `__module__`.
__name, __obj = "", None
for __name in dir(_C):
    # Public names only; names ending in "Base" are not listed.
    if __name[0] != "_" and not __name.endswith("Base"):
        __all__.append(__name)
        __obj = getattr(_C, __name)
        if callable(__obj) or inspect.isclass(__obj):
            if __obj.__module__ != __name__:  # "torch"
                # TODO: fix their module from C++ side
                # The three names below keep their original __module__.
                if __name not in {
                    "DisableTorchFunctionSubclass",
                    "DisableTorchFunction",
                    "Generator",
                }:
                    __obj.__module__ = __name__  # "torch"
    elif __name == "TensorBase":
        # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
        delattr(sys.modules[__name__], __name)

# Drop the loop temporaries from the module namespace.
del __name, __obj
if not TYPE_CHECKING:
    # issue 38137 and python issue 43367. Submodules of a C extension are
    # non-standard, and attributes of those submodules cannot be pickled since
    # pickle expect to be able to import them as "from _C.sub import attr"
    # which fails with "_C is not a package"
    def _import_extension_to_sys_modules(module, memo=None):
        """Register the submodules of ``module`` in ``sys.modules``, recursively.

        Args:
            module: module whose attributes are scanned for submodules.
            memo: set of modules already visited; created on the first call
                and shared with recursive calls so no module is scanned twice.
        """
        if memo is None:
            memo = set()
        if module in memo:
            return
        memo.add(module)
        module_name = module.__name__
        for name in dir(module):
            member = getattr(module, name)
            member_name = getattr(member, "__name__", "")
            # Only modules whose name starts with the parent's name are
            # treated as its submodules.
            if inspect.ismodule(member) and member_name.startswith(module_name):
                # setdefault: never replace an entry that is already registered
                sys.modules.setdefault(member_name, member)
                # Recurse for submodules (e.g., `_C._dynamo.eval_frame`)
                _import_extension_to_sys_modules(member, memo)

    _import_extension_to_sys_modules(_C)
    # One-shot helper: remove it from the module namespace once it has run.
    del _import_extension_to_sys_modules
  916. ################################################################################
  917. # Define basic utilities
  918. ################################################################################
  919. def typename(obj: _Any, /) -> str:
  920. """
  921. String representation of the type of an object.
  922. This function returns a fully qualified string representation of an object's type.
  923. Args:
  924. obj (object): The object whose type to represent
  925. Returns:
  926. str: the type of the object `o`
  927. Example:
  928. >>> x = torch.tensor([1, 2, 3])
  929. >>> torch.typename(x)
  930. 'torch.LongTensor'
  931. >>> torch.typename(torch.nn.Parameter)
  932. 'torch.nn.parameter.Parameter'
  933. """
  934. if isinstance(obj, torch.Tensor):
  935. return obj.type()
  936. module = getattr(obj, "__module__", "") or ""
  937. qualname = ""
  938. if hasattr(obj, "__qualname__"):
  939. qualname = obj.__qualname__
  940. elif hasattr(obj, "__name__"):
  941. qualname = obj.__name__
  942. else:
  943. module = obj.__class__.__module__ or ""
  944. qualname = obj.__class__.__qualname__
  945. if module in {"", "builtins"}:
  946. return qualname
  947. return f"{module}.{qualname}"
def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]:
    r"""Returns True if `obj` is a PyTorch tensor.

    This is exactly ``isinstance(obj, torch.Tensor)``, so instances of
    ``torch.Tensor`` subclasses also count.

    Args:
        obj (object): Object to test

    Example::

        >>> x = torch.tensor([1, 2, 3])
        >>> torch.is_tensor(x)
        True
    """
    return isinstance(obj, torch.Tensor)
def is_storage(obj: _Any, /) -> builtins.bool:
    r"""Returns True if `obj` is a PyTorch storage object.

    The check compares the exact type of `obj` against the module-level
    ``_storage_classes`` collection, so an instance of a subclass that is not
    itself listed there is not recognized.

    Args:
        obj (Object): Object to test

    Example::

        >>> import torch
        >>> # UntypedStorage (recommended)
        >>> tensor = torch.tensor([1, 2, 3])
        >>> storage = tensor.untyped_storage()
        >>> torch.is_storage(storage)
        True
        >>>
        >>> # TypedStorage (legacy)
        >>> typed_storage = torch.TypedStorage(5, dtype=torch.float32)
        >>> torch.is_storage(typed_storage)
        True
        >>>
        >>> # regular tensor (should return False)
        >>> torch.is_storage(tensor)
        False
        >>>
        >>> # non-storage object
        >>> torch.is_storage([1, 2, 3])
        False
    """
    return type(obj) in _storage_classes
  984. _GLOBAL_DEVICE_CONTEXT = threading.local()
  985. def get_default_device() -> "torch.device":
  986. r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
  987. global _GLOBAL_DEVICE_CONTEXT
  988. from torch.overrides import _get_current_function_mode_stack
  989. from torch.utils._device import DeviceContext
  990. def _get_device_with_index(device):
  991. if device.index is not None:
  992. return device
  993. else:
  994. # TODO: Call like get_device_index() method corresponding to
  995. # each device type
  996. return torch.tensor([]).device
  997. # Get device from any active DeviceContext.
  998. device_mode = next(
  999. filter(
  1000. lambda mode: isinstance(mode, DeviceContext),
  1001. reversed(_get_current_function_mode_stack()),
  1002. ),
  1003. None,
  1004. )
  1005. if device_mode:
  1006. device = device_mode.device
  1007. return _get_device_with_index(device)
  1008. device_context = getattr(_GLOBAL_DEVICE_CONTEXT, "device_context", None)
  1009. if device_context is not None:
  1010. return _get_device_with_index(device_context.device)
  1011. return torch.device("cpu")
  1012. def set_default_device(device: "Device") -> None:
  1013. """Sets the default ``torch.Tensor`` to be allocated on ``device``. This
  1014. does not affect factory function calls which are called with an explicit
  1015. ``device`` argument. Factory calls will be performed as if they
  1016. were passed ``device`` as an argument.
  1017. To only temporarily change the default device instead of setting it
  1018. globally, use ``with torch.device(device):`` instead.
  1019. The default device is initially ``cpu``. If you set the default tensor
  1020. device to another device (e.g., ``cuda``) without a device index, tensors
  1021. will be allocated on whatever the current device for the device type,
  1022. even after :func:`torch.cuda.set_device` is called.
  1023. .. warning::
  1024. This function imposes a slight performance cost on every Python
  1025. call to the torch API (not just factory functions). If this
  1026. is causing problems for you, please comment on
  1027. https://github.com/pytorch/pytorch/issues/92701
  1028. .. note::
  1029. This doesn't affect functions that create tensors that share the same memory as the input, like:
  1030. :func:`torch.from_numpy` and :func:`torch.frombuffer`
  1031. Args:
  1032. device (device or string): the device to set as default
  1033. Example::
  1034. >>> # xdoctest: +SKIP("requires cuda, changes global state")
  1035. >>> torch.get_default_device()
  1036. device(type='cpu')
  1037. >>> torch.set_default_device('cuda') # current device is 0
  1038. >>> torch.get_default_device()
  1039. device(type='cuda', index=0)
  1040. >>> torch.set_default_device('cuda')
  1041. >>> torch.cuda.set_device('cuda:1') # current device is 1
  1042. >>> torch.get_default_device()
  1043. device(type='cuda', index=1)
  1044. >>> torch.set_default_device('cuda:1')
  1045. >>> torch.get_default_device()
  1046. device(type='cuda', index=1)
  1047. """
  1048. global _GLOBAL_DEVICE_CONTEXT
  1049. if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
  1050. device_context = _GLOBAL_DEVICE_CONTEXT.device_context
  1051. if device_context is not None:
  1052. device_context.__exit__(None, None, None)
  1053. if device is None:
  1054. device_context = None
  1055. else:
  1056. from torch.utils._device import DeviceContext
  1057. device_context = DeviceContext(device)
  1058. device_context.__enter__()
  1059. _GLOBAL_DEVICE_CONTEXT.device_context = device_context
  1060. def set_default_tensor_type(t: type["torch.Tensor"] | str, /) -> None:
  1061. r"""
  1062. .. warning::
  1063. This function is deprecated as of PyTorch 2.1, please use :func:`torch.set_default_dtype()` and
  1064. :func:`torch.set_default_device()` as alternatives.
  1065. Sets the default ``torch.Tensor`` type to floating point tensor type
  1066. ``t``. This type will also be used as default floating point type for
  1067. type inference in :func:`torch.tensor`.
  1068. The default floating point tensor type is initially ``torch.FloatTensor``.
  1069. Args:
  1070. t (type or string): the floating point tensor type or its name
  1071. Example::
  1072. >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
  1073. >>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32
  1074. torch.float32
  1075. >>> torch.set_default_tensor_type(torch.DoubleTensor)
  1076. >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
  1077. torch.float64
  1078. """
  1079. if isinstance(t, str):
  1080. t = _import_dotted_name(t)
  1081. _C._set_default_tensor_type(t)
def set_default_dtype(d: "torch.dtype", /) -> None:
    r"""

    Sets the default floating point dtype to :attr:`d`. Supports floating point dtype
    as inputs. Other dtypes will cause torch to raise an exception.

    When PyTorch is initialized its default floating point dtype is torch.float32,
    and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like
    type inference. The default floating point dtype is used to:

    1. Implicitly determine the default complex dtype. When the default floating type is float16,
       the default complex dtype is complex32. For float32, the default complex dtype is complex64.
       For float64, it is complex128. For bfloat16, an exception will be raised because
       there is no corresponding complex type for bfloat16.
    2. Infer the dtype for tensors constructed using Python floats or complex Python
       numbers. See examples below.
    3. Determine the result of type promotion between bool and integer tensors and
       Python floats and complex Python numbers.

    Args:
        d (:class:`torch.dtype`): the floating point dtype to make the default.

    Example:
        >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
        >>> # initial default for floating point is torch.float32
        >>> # Python floats are interpreted as float32
        >>> torch.tensor([1.2, 3]).dtype
        torch.float32
        >>> # initial default for complex is torch.complex64
        >>> # Complex Python numbers are interpreted as complex64
        >>> torch.tensor([1.2, 3j]).dtype
        torch.complex64

        >>> torch.set_default_dtype(torch.float64)
        >>> # Python floats are now interpreted as float64
        >>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
        torch.float64
        >>> # Complex Python numbers are now interpreted as complex128
        >>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
        torch.complex128

        >>> torch.set_default_dtype(torch.float16)
        >>> # Python floats are now interpreted as float16
        >>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
        torch.float16
        >>> # Complex Python numbers are now interpreted as complex32
        >>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
        torch.complex32

    """
    # All validation and state changes happen in the C binding.
    _C._set_default_dtype(d)
  1125. def use_deterministic_algorithms(
  1126. mode: builtins.bool,
  1127. *,
  1128. warn_only: builtins.bool = False,
  1129. ) -> None:
  1130. r"""Sets whether PyTorch operations must use "deterministic"
  1131. algorithms. That is, algorithms which, given the same input, and when
  1132. run on the same software and hardware, always produce the same output.
  1133. When enabled, operations will use deterministic algorithms when available,
  1134. and if only nondeterministic algorithms are available they will throw a
  1135. :class:`RuntimeError` when called.
  1136. .. note:: This setting alone is not always enough to make an application
  1137. reproducible. Refer to :ref:`reproducibility` for more information.
  1138. .. note:: :func:`torch.set_deterministic_debug_mode` offers an alternative
  1139. interface for this feature.
  1140. The following normally-nondeterministic operations will act
  1141. deterministically when ``mode=True``:
  1142. * :class:`torch.nn.Conv1d` when called on CUDA tensor
  1143. * :class:`torch.nn.Conv2d` when called on CUDA tensor
  1144. * :class:`torch.nn.Conv3d` when called on CUDA tensor
  1145. * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor
  1146. * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor
  1147. * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor
  1148. * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor
  1149. * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
  1150. * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor
  1151. * :func:`torch.bmm` when called on sparse-dense CUDA tensors
  1152. * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
  1153. and the index is a list of tensors
  1154. * :func:`torch.Tensor.index_put` with ``accumulate=False``
  1155. * :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU
  1156. tensor
  1157. * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
  1158. tensor
  1159. * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
  1160. * :func:`torch.gather` when called on a CUDA tensor that requires grad
  1161. * :func:`torch.index_add` when called on CUDA tensor
  1162. * :func:`torch.index_select` when attempting to differentiate a CUDA tensor
  1163. * :func:`torch.repeat_interleave` when attempting to differentiate a CUDA tensor
  1164. * :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
  1165. * :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor
  1166. * :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor
  1167. The following normally-nondeterministic operations will throw a
  1168. :class:`RuntimeError` when ``mode=True``:
  1169. * :class:`torch.nn.AvgPool3d` when attempting to differentiate a CUDA tensor
  1170. * :class:`torch.nn.AdaptiveAvgPool2d` when attempting to differentiate a CUDA tensor
  1171. * :class:`torch.nn.AdaptiveAvgPool3d` when attempting to differentiate a CUDA tensor
  1172. * :class:`torch.nn.MaxPool3d` when attempting to differentiate a CUDA tensor
  1173. * :class:`torch.nn.AdaptiveMaxPool2d` when attempting to differentiate a CUDA tensor
  1174. * :class:`torch.nn.FractionalMaxPool2d` when attempting to differentiate a CUDA tensor
  1175. * :class:`torch.nn.FractionalMaxPool3d` when attempting to differentiate a CUDA tensor
  1176. * :class:`torch.nn.MaxUnpool1d`
  1177. * :class:`torch.nn.MaxUnpool2d`
  1178. * :class:`torch.nn.MaxUnpool3d`
  1179. * :func:`torch.nn.functional.interpolate` when attempting to differentiate a CUDA tensor
  1180. and one of the following modes is used:
  1181. - ``linear``
  1182. - ``bilinear``
  1183. - ``bicubic``
  1184. - ``trilinear``
  1185. * :class:`torch.nn.ReflectionPad1d` when attempting to differentiate a CUDA tensor
  1186. * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor
  1187. * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor
  1188. * :class:`torch.nn.NLLLoss` when called on a CUDA tensor
  1189. * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor
  1190. * :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when
  1191. ``mode='max'``
  1192. * :func:`torch.Tensor.put_` when ``accumulate=False``
  1193. * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
  1194. * :func:`torch.histc` when called on a CUDA tensor
  1195. * :func:`torch.bincount` when called on a CUDA tensor and ``weights``
  1196. tensor is given
  1197. * :func:`torch.median` with indices output when called on a CUDA tensor
  1198. * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor
  1199. * :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex
  1200. * :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor
  1201. * :func:`torch.Tensor.resize_` when called with a quantized tensor
  1202. In addition, several operations fill uninitialized memory when this setting
  1203. is turned on and when
  1204. :attr:`torch.utils.deterministic.fill_uninitialized_memory` is turned on.
  1205. See the documentation for that attribute for more information.
  1206. Note that deterministic operations tend to have worse performance than
  1207. nondeterministic operations.
  1208. When this setting is turned on, the Inductor deterministic mode is also turned on
  1209. automatically. In deterministic mode, Inductor avoids on-device benchmarking
  1210. that affects numerics. This includes:
  1211. - don't pad matmul input shapes. Without enabling deterministic mode, Inductor would do
  1212. benchmarking to check if padding matmul shape is beneficial.
  1213. - don't autotune templates. Inductor has templates for kernels like matmul/conv/attention.
  1214. Without enabling deterministic mode, Inductor would do autotuning to
  1215. pick the best configs for those templates and adopt it if it's faster
  1216. than the kernel in eager mode. In deterministic mode, we pick the eager kernel.
  1217. - don't autotune triton configs for reduction. Reduction numerics are
  1218. very sensitive to triton configs. In deterministic mode, Inductor
  1219. will use some heuristics to pick the most promising configs rather
  1220. than do autotuning.
  1221. - Skip autotuning for reduction in coordinate descent tuning.
  1222. - Don't benchmark for the computation/communication reordering pass.
  1223. - Disable the feature that dynamically scales down the RBLOCK triton config for higher
  1224. occupancy.
  1225. .. note::
  1226. This flag does not detect or prevent nondeterministic behavior caused
  1227. by calling an inplace operation on a tensor with an internal memory
  1228. overlap or by giving such a tensor as the :attr:`out` argument for an
  1229. operation. In these cases, multiple writes of different data may target
  1230. a single memory location, and the order of writes is not guaranteed.
  1231. Args:
  1232. mode (:class:`bool`): If True, makes potentially nondeterministic
  1233. operations switch to a deterministic algorithm or throw a runtime
  1234. error. If False, allows nondeterministic operations.
  1235. Keyword args:
  1236. warn_only (:class:`bool`, optional): If True, operations that do not
  1237. have a deterministic implementation will throw a warning instead of
  1238. an error. Default: ``False``
  1239. Example::
  1240. >>> # xdoctest: +SKIP
  1241. >>> torch.use_deterministic_algorithms(True)
  1242. # Backward mode nondeterministic error
  1243. >>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward()
  1244. ...
  1245. RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation...
  1246. """
  1247. import torch._inductor.config as inductor_config
  1248. inductor_config.deterministic = mode
  1249. _C._set_deterministic_algorithms(mode, warn_only=warn_only)
  1250. def are_deterministic_algorithms_enabled() -> builtins.bool:
  1251. r"""Returns True if the global deterministic flag is turned on. Refer to
  1252. :func:`torch.use_deterministic_algorithms` documentation for more details.
  1253. """
  1254. return _C._get_deterministic_algorithms()
  1255. def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool:
  1256. r"""Returns True if the global deterministic flag is set to warn only.
  1257. Refer to :func:`torch.use_deterministic_algorithms` documentation for more
  1258. details.
  1259. """
  1260. return _C._get_deterministic_algorithms_warn_only()
  1261. def set_deterministic_debug_mode(debug_mode: builtins.int | str) -> None:
  1262. r"""Sets the debug mode for deterministic operations.
  1263. .. note:: This is an alternative interface for
  1264. :func:`torch.use_deterministic_algorithms`. Refer to that function's
  1265. documentation for details about affected operations.
  1266. Args:
  1267. debug_mode(str or int): If "default" or 0, don't error or warn on
  1268. nondeterministic operations. If "warn" or 1, warn on
  1269. nondeterministic operations. If "error" or 2, error on
  1270. nondeterministic operations.
  1271. """
  1272. # NOTE: builtins.int is used here because int in this scope resolves
  1273. # to torch.int
  1274. if not isinstance(debug_mode, (builtins.int, str)):
  1275. raise TypeError(f"debug_mode must be str or int, but got {type(debug_mode)}")
  1276. if isinstance(debug_mode, str):
  1277. if debug_mode == "default":
  1278. debug_mode = 0
  1279. elif debug_mode == "warn":
  1280. debug_mode = 1
  1281. elif debug_mode == "error":
  1282. debug_mode = 2
  1283. else:
  1284. raise RuntimeError(
  1285. "invalid value of debug_mode, expected one of `default`, "
  1286. f"`warn`, `error`, but got {debug_mode}"
  1287. )
  1288. if debug_mode == 0:
  1289. _C._set_deterministic_algorithms(False)
  1290. elif debug_mode == 1:
  1291. _C._set_deterministic_algorithms(True, warn_only=True)
  1292. elif debug_mode == 2:
  1293. _C._set_deterministic_algorithms(True)
  1294. else:
  1295. raise RuntimeError(
  1296. f"invalid value of debug_mode, expected 0, 1, or 2, but got {debug_mode}"
  1297. )
  1298. def get_deterministic_debug_mode() -> builtins.int:
  1299. r"""Returns the current value of the debug mode for deterministic
  1300. operations. Refer to :func:`torch.set_deterministic_debug_mode`
  1301. documentation for more details.
  1302. """
  1303. if _C._get_deterministic_algorithms():
  1304. if _C._get_deterministic_algorithms_warn_only():
  1305. return 1
  1306. else:
  1307. return 2
  1308. else:
  1309. return 0
  1310. def get_float32_matmul_precision() -> str:
  1311. r"""Returns the current value of float32 matrix multiplication precision. Refer to
  1312. :func:`torch.set_float32_matmul_precision` documentation for more details.
  1313. """
  1314. return _C._get_float32_matmul_precision()
  1315. def set_float32_matmul_precision(precision: str) -> None:
  1316. r"""Sets the internal precision of float32 matrix multiplications.
  1317. Running float32 matrix multiplications in lower precision may significantly increase
  1318. performance, and in some programs the loss of precision has a negligible impact.
  1319. Supports three settings:
  1320. * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa
  1321. bits with 23 bits explicitly stored) for internal computations.
  1322. * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10
  1323. mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers
  1324. (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication
  1325. algorithms are available. Otherwise float32 matrix multiplications are computed
  1326. as if the precision is "highest". See below for more information on the bfloat16
  1327. approach.
  1328. * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa
  1329. bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm
  1330. using that datatype internally is available. Otherwise float32
  1331. matrix multiplications are computed as if the precision is "high".
  1332. When using "high" precision, float32 multiplications may use a bfloat16-based algorithm
  1333. that is more complicated than simply truncating to some smaller number mantissa bits
  1334. (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored). Refer to [Henry2019]_ for a complete
  1335. description of this algorithm. To briefly explain here, the first step is to realize
  1336. that we can perfectly encode a single float32 number as the sum of three bfloat16
  1337. numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the
  1338. same number of exponent bits). This means that the product of two float32 numbers can
  1339. be exactly given by the sum of nine products of bfloat16 numbers. We can then trade
  1340. accuracy for speed by dropping some of these products. The "high" precision algorithm
  1341. specifically keeps only the three most significant products, which conveniently excludes
  1342. all of the products involving the last 8 mantissa bits of either input. This means that
  1343. we can represent our inputs as the sum of two bfloat16 numbers rather than three.
  1344. Because bfloat16 fused-multiply-add (FMA) instructions are typically >10x faster than
  1345. float32 ones, it's faster to do three multiplications and 2 additions with bfloat16
  1346. precision than it is to do a single multiplication with float32 precision.
  1347. .. [Henry2019] http://arxiv.org/abs/1904.06376
  1348. .. note::
  1349. This does not change the output dtype of float32 matrix multiplications,
  1350. it controls how the internal computation of the matrix multiplication is performed.
  1351. .. note::
  1352. This does not change the precision of convolution operations. Other flags,
  1353. like `torch.backends.cudnn.allow_tf32`, may control the precision of convolution
  1354. operations.
  1355. .. note::
  1356. This flag currently only affects one native device type: CUDA.
  1357. If "high" or "medium" are set then the TensorFloat32 datatype will be used
  1358. when computing float32 matrix multiplications, equivalent to setting
  1359. `torch.backends.cuda.matmul.allow_tf32 = True`. When "highest" (the default)
  1360. is set then the float32 datatype is used for internal computations, equivalent
  1361. to setting `torch.backends.cuda.matmul.allow_tf32 = False`.
  1362. Args:
  1363. precision(str): can be set to "highest" (default), "high", or "medium" (see above).
  1364. """
  1365. _C._set_float32_matmul_precision(precision)
  1366. def set_warn_always(b: builtins.bool, /) -> None:
  1367. r"""When this flag is False (default) then some PyTorch warnings may only
  1368. appear once per process. This helps avoid excessive warning information.
  1369. Setting it to True causes these warnings to always appear, which may be
  1370. helpful when debugging.
  1371. Args:
  1372. b (:class:`bool`): If True, force warnings to always be emitted
  1373. If False, set to the default behaviour
  1374. """
  1375. _C._set_warnAlways(b)
  1376. def is_warn_always_enabled() -> builtins.bool:
  1377. r"""Returns True if the global warn_always flag is turned on. Refer to
  1378. :func:`torch.set_warn_always` documentation for more details.
  1379. """
  1380. return _C._get_warnAlways()
  1381. ################################################################################
  1382. # Define error checking functions
  1383. ################################################################################
  1384. # These error checking functions must be kept consistent with their C++
  1385. # equivalents. Their C++ equivalents are mentioned where applicable.
  1386. def _check_with(
  1387. error_type,
  1388. cond: builtins.bool | SymBool,
  1389. message: _Callable[[], str],
  1390. ): # noqa: F811
  1391. if not isinstance(cond, (builtins.bool, SymBool)):
  1392. raise TypeError(f"cond must be a bool, but got {type(cond)}")
  1393. from torch.fx.experimental.symbolic_shapes import expect_true
  1394. if expect_true(cond):
  1395. return
  1396. # error_type must be a subclass of Exception and not subclass of Warning
  1397. if not issubclass(error_type, Exception) or issubclass(error_type, Warning):
  1398. raise AssertionError(
  1399. f"error_type must be a subclass of Exception but not Warning, got {error_type}"
  1400. )
  1401. if message is None:
  1402. message_evaluated = (
  1403. "Expected cond to be True, but got False. (Could this error "
  1404. "message be improved? If so, please report an enhancement request "
  1405. "to PyTorch.)"
  1406. )
  1407. else:
  1408. if not callable(message):
  1409. raise TypeError("message must be a callable")
  1410. message_evaluated = str(message())
  1411. raise error_type(message_evaluated)
  1412. def _check(cond, message=None): # noqa: F811
  1413. r"""Throws error containing an optional message if the specified condition
  1414. is False.
  1415. Error type: ``RuntimeError``
  1416. C++ equivalent: ``TORCH_CHECK``
  1417. Args:
  1418. cond (:class:`bool`): If False, throw error
  1419. message (Callable, optional): Callable that returns either a string or
  1420. an object that has a ``__str__()`` method to be used as the error
  1421. message. Default: ``None``
  1422. """
  1423. _check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type]
  1424. @_deprecated(
  1425. "_check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. \
  1426. Use _check(i >= 0) instead.",
  1427. category=FutureWarning,
  1428. )
  1429. def _check_is_size(i, message=None, *, max=None):
  1430. """Checks that a given integer is a valid size (i.e., is non-negative).
  1431. You should use this over ``_check(i >= 0)`` because it can prevent
  1432. ``GuardOnDataDependentSymNode`` exceptions by opting yourself into alternate
  1433. semantics for ``guard_size_oblivious`` tests that treat values 0 and 1
  1434. equivalently to all other values.
  1435. When max is not None, this specifies an upper bound equivalent to
  1436. ``_check(i <= max)``. This bound is also subject to alternate semantics:
  1437. in ``guard_size_oblivious`` tests, we assume that a constant max bound is
  1438. treated equivalently to all other values. Symbolic max bounds are not yet
  1439. supported.
  1440. NB: Do NOT use this in contexts where a -1 size would be valid (indicating
  1441. to infer the size from context, or if you should wrap-around or truncate).
  1442. Only use this if the only valid value is an honest to goodness size.
  1443. """
  1444. # This is responsible for the expect_true
  1445. _check(i >= 0, message)
  1446. from torch.fx.experimental.symbolic_shapes import _advise_is_size
  1447. _advise_is_size(i)
  1448. if max is not None:
  1449. _check(i <= max, message)
  1450. from torch.fx.experimental.symbolic_shapes import _advise_is_bounded
  1451. _advise_is_bounded(i, max)
  1452. def _check_index(cond, message=None): # noqa: F811
  1453. r"""Throws error containing an optional message if the specified condition
  1454. is False.
  1455. Error type: ``IndexError``
  1456. C++ equivalent: ``TORCH_CHECK_INDEX``
  1457. Args:
  1458. cond (:class:`bool`): If False, throw error
  1459. message (Callable, optional): Callable that returns either a string or
  1460. an object that has a ``__str__()`` method to be used as the error
  1461. message. Default: ``None``
  1462. """
  1463. _check_with(IndexError, cond, message) # pyrefly: ignore [bad-argument-type]
  1464. def _check_value(cond, message=None): # noqa: F811
  1465. r"""Throws error containing an optional message if the specified condition
  1466. is False.
  1467. Error type: ``ValueError``
  1468. C++ equivalent: ``TORCH_CHECK_VALUE``
  1469. Args:
  1470. cond (:class:`bool`): If False, throw error
  1471. message (Callable, optional): Callable that returns either a string or
  1472. an object that has a ``__str__()`` method to be used as the error
  1473. message. Default: ``None``
  1474. """
  1475. _check_with(ValueError, cond, message) # pyrefly: ignore [bad-argument-type]
  1476. def _check_type(cond, message=None): # noqa: F811
  1477. r"""Throws error containing an optional message if the specified condition
  1478. is False.
  1479. Error type: ``TypeError``
  1480. C++ equivalent: ``TORCH_CHECK_TYPE``
  1481. Args:
  1482. cond (:class:`bool`): If False, throw error
  1483. message (Callable, optional): Callable that returns either a string or
  1484. an object that has a ``__str__()`` method to be used as the error
  1485. message. Default: ``None``
  1486. """
  1487. _check_with(TypeError, cond, message) # pyrefly: ignore [bad-argument-type]
  1488. def _check_not_implemented(cond, message=None): # noqa: F811
  1489. r"""Throws error containing an optional message if the specified condition
  1490. is False.
  1491. Error type: ``NotImplementedError``
  1492. C++ equivalent: ``TORCH_CHECK_NOT_IMPLEMENTED``
  1493. Args:
  1494. cond (:class:`bool`): If False, throw error
  1495. message (Callable, optional): Callable that returns either a string or
  1496. an object that has a ``__str__()`` method to be used as the error
  1497. message. Default: ``None``
  1498. """
  1499. _check_with(
  1500. NotImplementedError,
  1501. cond,
  1502. # pyrefly: ignore [bad-argument-type]
  1503. message,
  1504. )
  1505. def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
  1506. if not is_tensor(cond):
  1507. raise TypeError(f"cond must be a tensor, but got {type(cond)}")
  1508. if not cond.dtype == torch.bool:
  1509. raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}")
  1510. _check_with(error_type, cond._is_all_true().item(), message) # type: ignore[arg-type]
  1511. # C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
  1512. def _check_tensor_all(cond, message=None): # noqa: F811
  1513. r"""Throws error containing an optional message if the specified condition
  1514. is False.
  1515. Error type: ``RuntimeError``
  1516. C++ equivalent: ``TORCH_CHECK_TENSOR_ALL``
  1517. Args:
  1518. cond (:class:`torch.Tensor`): Tensor of dtype ``torch.bool``. If any
  1519. element is ``False``, throw error
  1520. message (Callable, optional): Callable that returns either a string or
  1521. an object that has a ``__str__()`` method to be used as the error
  1522. message. Default: ``None``
  1523. """
  1524. _check_tensor_all_with(RuntimeError, cond, message)
  1525. ################################################################################
  1526. # Define numeric constants
  1527. ################################################################################
  1528. # For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and
  1529. # NumPy consistency (https://numpy.org/devdocs/reference/constants.html)
  1530. from math import e, inf, nan, pi
# ``newaxis`` is simply an alias for ``None``.
newaxis: None = None

# Publish the math constants imported above, plus ``newaxis``, via ``__all__``.
__all__.extend(["e", "pi", "nan", "inf", "newaxis"])
  1533. ################################################################################
  1534. # Define Storage and Tensor classes
  1535. ################################################################################
  1536. from torch._tensor import Tensor # usort: skip
  1537. # needs to be after torch.Tensor is defined to avoid circular dependencies
  1538. from torch import storage as storage # usort: skip
  1539. from torch.storage import (
  1540. _LegacyStorage,
  1541. _StorageBase,
  1542. _warn_typed_storage_removal,
  1543. TypedStorage,
  1544. UntypedStorage,
  1545. )
  1546. # NOTE: New <type>Storage classes should never be added. When adding a new
  1547. # dtype, use torch.storage.TypedStorage directly.
  1548. class ByteStorage(_LegacyStorage):
  1549. @classproperty
  1550. def dtype(self):
  1551. _warn_typed_storage_removal(stacklevel=3)
  1552. return self._dtype
  1553. @classproperty
  1554. def _dtype(self):
  1555. return torch.uint8
  1556. class DoubleStorage(_LegacyStorage):
  1557. @classproperty
  1558. def dtype(self):
  1559. _warn_typed_storage_removal(stacklevel=3)
  1560. return self._dtype
  1561. @classproperty
  1562. def _dtype(self):
  1563. return torch.double
  1564. class FloatStorage(_LegacyStorage):
  1565. @classproperty
  1566. def dtype(self):
  1567. _warn_typed_storage_removal(stacklevel=3)
  1568. return self._dtype
  1569. @classproperty
  1570. def _dtype(self):
  1571. return torch.float
  1572. class HalfStorage(_LegacyStorage):
  1573. @classproperty
  1574. def dtype(self):
  1575. _warn_typed_storage_removal(stacklevel=3)
  1576. return self._dtype
  1577. @classproperty
  1578. def _dtype(self):
  1579. return torch.half
  1580. class LongStorage(_LegacyStorage):
  1581. @classproperty
  1582. def dtype(self):
  1583. _warn_typed_storage_removal(stacklevel=3)
  1584. return self._dtype
  1585. @classproperty
  1586. def _dtype(self):
  1587. return torch.long
  1588. class IntStorage(_LegacyStorage):
  1589. @classproperty
  1590. def dtype(self):
  1591. _warn_typed_storage_removal(stacklevel=3)
  1592. return self._dtype
  1593. @classproperty
  1594. def _dtype(self):
  1595. return torch.int
  1596. class ShortStorage(_LegacyStorage):
  1597. @classproperty
  1598. def dtype(self):
  1599. _warn_typed_storage_removal(stacklevel=3)
  1600. return self._dtype
  1601. @classproperty
  1602. def _dtype(self):
  1603. return torch.short
  1604. class CharStorage(_LegacyStorage):
  1605. @classproperty
  1606. def dtype(self):
  1607. _warn_typed_storage_removal(stacklevel=3)
  1608. return self._dtype
  1609. @classproperty
  1610. def _dtype(self):
  1611. return torch.int8
  1612. class BoolStorage(_LegacyStorage):
  1613. @classproperty
  1614. def dtype(self):
  1615. _warn_typed_storage_removal(stacklevel=3)
  1616. return self._dtype
  1617. @classproperty
  1618. def _dtype(self):
  1619. return torch.bool
  1620. class BFloat16Storage(_LegacyStorage):
  1621. @classproperty
  1622. def dtype(self):
  1623. _warn_typed_storage_removal(stacklevel=3)
  1624. return self._dtype
  1625. @classproperty
  1626. def _dtype(self):
  1627. return torch.bfloat16
  1628. class ComplexDoubleStorage(_LegacyStorage):
  1629. @classproperty
  1630. def dtype(self):
  1631. _warn_typed_storage_removal(stacklevel=3)
  1632. return self._dtype
  1633. @classproperty
  1634. def _dtype(self):
  1635. return torch.cdouble
  1636. class ComplexFloatStorage(_LegacyStorage):
  1637. @classproperty
  1638. def dtype(self):
  1639. _warn_typed_storage_removal(stacklevel=3)
  1640. return self._dtype
  1641. @classproperty
  1642. def _dtype(self):
  1643. return torch.cfloat
  1644. class QUInt8Storage(_LegacyStorage):
  1645. @classproperty
  1646. def dtype(self):
  1647. _warn_typed_storage_removal(stacklevel=3)
  1648. return self._dtype
  1649. @classproperty
  1650. def _dtype(self):
  1651. return torch.quint8
  1652. class QInt8Storage(_LegacyStorage):
  1653. @classproperty
  1654. def dtype(self):
  1655. _warn_typed_storage_removal(stacklevel=3)
  1656. return self._dtype
  1657. @classproperty
  1658. def _dtype(self):
  1659. return torch.qint8
  1660. class QInt32Storage(_LegacyStorage):
  1661. @classproperty
  1662. def dtype(self):
  1663. _warn_typed_storage_removal(stacklevel=3)
  1664. return self._dtype
  1665. @classproperty
  1666. def _dtype(self):
  1667. return torch.qint32
  1668. class QUInt4x2Storage(_LegacyStorage):
  1669. @classproperty
  1670. def dtype(self):
  1671. _warn_typed_storage_removal(stacklevel=3)
  1672. return self._dtype
  1673. @classproperty
  1674. def _dtype(self):
  1675. return torch.quint4x2
  1676. class QUInt2x4Storage(_LegacyStorage):
  1677. @classproperty
  1678. def dtype(self):
  1679. _warn_typed_storage_removal(stacklevel=3)
  1680. return self._dtype
  1681. @classproperty
  1682. def _dtype(self):
  1683. return torch.quint2x4
# The storage classes of this module: UntypedStorage, TypedStorage and every
# <type>Storage class listed here. The set is converted to a list and passed
# to _C._init_names further down in this file.
_storage_classes: set[type[TypedStorage | UntypedStorage]] = {
    UntypedStorage,
    DoubleStorage,
    FloatStorage,
    LongStorage,
    IntStorage,
    ShortStorage,
    CharStorage,
    ByteStorage,
    HalfStorage,
    BoolStorage,
    QUInt8Storage,
    QInt8Storage,
    QInt32Storage,
    BFloat16Storage,
    ComplexFloatStorage,
    ComplexDoubleStorage,
    QUInt4x2Storage,
    QUInt2x4Storage,
    TypedStorage,
}

# The _tensor_classes set is initialized by the call to initialize_python_bindings.
# It starts out empty here.
_tensor_classes: set[type["torch.Tensor"]] = set()
  1707. # If you edit these imports, please update torch/__init__.py.in as well
  1708. from torch import amp as amp, random as random, serialization as serialization
  1709. from torch._tensor_str import set_printoptions
  1710. from torch.amp import autocast, GradScaler
  1711. from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state
  1712. from torch.serialization import load, save
  1713. ################################################################################
  1714. # Initialize extension
  1715. ################################################################################
  1716. # Shared memory manager needs to know the exact location of manager executable
  1717. def _manager_path():
  1718. if platform.system() == "Windows":
  1719. return b""
  1720. path = get_file_path("torch", "bin", "torch_shm_manager")
  1721. prepare_multiprocessing_environment(get_file_path("torch"))
  1722. if not os.path.exists(path):
  1723. raise RuntimeError("Unable to find torch_shm_manager at " + path)
  1724. return path.encode("utf-8")
# Initialize the C extension with the manager path computed above, then remove
# the one-shot helper from this module's namespace.
_C._initExtension(_manager_path())
del _manager_path
# Appease the type checker: it can't deal with direct setting of globals().
# Note that we will see "too many" functions when reexporting this way; there
# is not a good way to fix this problem. Perhaps, try to redesign VariableFunctions
# so that this import is good enough
# (The runtime counterpart is the loop over dir(_C._VariableFunctions) that
# follows this block; nothing in here executes outside type checking.)
if TYPE_CHECKING:
    # Some type signatures pulled in from _VariableFunctions here clash with
    # signatures already imported. For now these clashes are ignored; see
    # PR #43339 for details.
    from torch._C._VariableFunctions import *  # type: ignore[assignment, misc] # noqa: F403

    # Fixup segment_reduce visibility
    # Mirror the runtime loop: expose the op as ``_segment_reduce`` and hide
    # the public spelling from the type checker.
    _segment_reduce = segment_reduce
    del segment_reduce  # noqa: F821
# Ops not to be exposed in `torch` namespace,
# mostly helper ops.
PRIVATE_OPS = ("unique_dim",)

# Pre-bind both loop variables, presumably so the ``del`` after the loop cannot
# fail if the loop body never assigns them.
__name, __obj = "", None
for __name in dir(_C._VariableFunctions):
    # Skip dunder attributes and anything listed in PRIVATE_OPS.
    if __name.startswith("__") or __name in PRIVATE_OPS:
        continue
    __obj = getattr(_C._VariableFunctions, __name)
    __obj.__module__ = __name__  # "torch"
    # Hide some APIs that should not be public
    if __name == "segment_reduce":
        # TODO: Once the undocumented FC window is passed, remove the line below
        globals()[__name] = __obj
        # From here on the op is handled under its underscore-prefixed name, so
        # it is also bound as ``_segment_reduce`` and is not added to __all__.
        __name = "_" + __name
    globals()[__name] = __obj
    # Only names without a leading underscore are exported.
    if not __name.startswith("_"):
        __all__.append(__name)

# Remove the loop variables from the module namespace.
del __name, __obj
  1757. ################################################################################
  1758. # Add torch.dtype instances to the public API
  1759. ################################################################################
  1760. import torch
# Append the name of every attribute of ``torch`` that is a ``torch.dtype``
# instance to ``__all__``.
__all__.extend(
    name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype)
)
  1764. ################################################################################
  1765. # Import TorchDynamo's lazy APIs to avoid circular dependencies
  1766. ################################################################################
  1767. # needs to be before from torch.functional import * to avoid circular dependencies
  1768. from torch._compile import _disable_dynamo # usort: skip
  1769. ################################################################################
  1770. # Import interface functions defined in Python
  1771. ################################################################################
  1772. # needs to be after the above ATen bindings so we can overwrite from Python side
  1773. from torch import _VF as _VF, functional as functional # usort: skip
  1774. from torch.functional import * # usort: skip # noqa: F403
  1775. ################################################################################
  1776. # Remove unnecessary members
  1777. ################################################################################
# Both names were imported from torch.storage earlier in this file; deleting
# them keeps them from remaining as attributes of this module.
del _StorageBase
del _LegacyStorage
  1780. ################################################################################
  1781. # Define _assert
  1782. ################################################################################
  1783. # needs to be before the submodule imports to avoid circular dependencies
  1784. def _assert(condition, message):
  1785. r"""A wrapper around Python's assert which is symbolically traceable."""
  1786. if type(condition) is not torch.Tensor and overrides.has_torch_function(
  1787. (condition,)
  1788. ):
  1789. return overrides.handle_torch_function(
  1790. _assert, (condition,), condition, message
  1791. )
  1792. if not condition:
  1793. raise AssertionError(message)
  1794. ################################################################################
  1795. # Import most common subpackages
  1796. ################################################################################
  1797. # Use the redundant form so that type checkers know that these are a part of
  1798. # the public API. The "regular" import lines are there solely for the runtime
  1799. # side effect of adding to the imported module's members for other users.
  1800. # needs to be before import torch.nn as nn to avoid circular dependencies
  1801. from torch.autograd import ( # usort: skip
  1802. enable_grad as enable_grad,
  1803. inference_mode as inference_mode,
  1804. no_grad as no_grad,
  1805. set_grad_enabled as set_grad_enabled,
  1806. )
  1807. from torch import (
  1808. __config__ as __config__,
  1809. __future__ as __future__,
  1810. _awaits as _awaits,
  1811. accelerator as accelerator,
  1812. autograd as autograd,
  1813. backends as backends,
  1814. cpu as cpu,
  1815. cuda as cuda,
  1816. distributed as distributed,
  1817. distributions as distributions,
  1818. fft as fft,
  1819. futures as futures,
  1820. hub as hub,
  1821. jit as jit,
  1822. linalg as linalg,
  1823. mps as mps,
  1824. mtia as mtia,
  1825. multiprocessing as multiprocessing,
  1826. nested as nested,
  1827. nn as nn,
  1828. optim as optim,
  1829. overrides as overrides,
  1830. profiler as profiler,
  1831. sparse as sparse,
  1832. special as special,
  1833. testing as testing,
  1834. types as types,
  1835. utils as utils,
  1836. version as version,
  1837. xpu as xpu,
  1838. )
  1839. from torch.signal import windows as windows
  1840. # Quantized, sparse, AO, etc. should be last to get imported, as nothing
  1841. # is expected to depend on them.
  1842. from torch import ao as ao # usort: skip
  1843. # nn.quant* depends on ao -- so should be after those.
  1844. import torch.nn.intrinsic
  1845. import torch.nn.qat
  1846. import torch.nn.quantizable
  1847. import torch.nn.quantized
# Hand the storage classes collected in _storage_classes (as a list) to the
# C extension's _init_names.
_C._init_names(list(_storage_classes))
  1849. # attach docstrings to torch and tensor functions
  1850. from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs
  1851. del _torch_docs, _tensor_docs, _storage_docs, _size_docs
  1852. def compiled_with_cxx11_abi() -> builtins.bool:
  1853. r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
  1854. return True
  1855. from torch import _library as _library, _ops as _ops
  1856. # Import the ops and classes "namespace"
  1857. from torch._ops import ops as ops # usort: skip
  1858. from torch._classes import classes as classes # usort: skip
# Register ``ops`` and ``classes`` in sys.modules under "<this module>.ops" and
# "<this module>.classes"; setdefault leaves any existing entry untouched.
sys.modules.setdefault(f"{__name__}.ops", ops)
sys.modules.setdefault(f"{__name__}.classes", classes)
  1861. # quantization depends on torch.fx and torch.ops
  1862. # Import quantization
  1863. from torch import quantization as quantization # usort: skip
  1864. # Import the quasi random sampler
  1865. from torch import quasirandom as quasirandom # usort: skip
# If you are seeing this, it means that this call site was not checked if
# the memory format could be preserved, and it was switched to old default
# behaviour of contiguous
# (``legacy_contiguous_format`` is a plain alias of ``contiguous_format``.)
legacy_contiguous_format = contiguous_format  # defined by _C._initExtension()
  1870. # Register fork handler to initialize OpenMP in child processes (see gh-28389)
  1871. from torch.multiprocessing._atfork import register_after_fork
  1872. register_after_fork(torch.get_num_threads)
  1873. del register_after_fork
  1874. # Import tools that require fully imported torch (for applying
  1875. # torch.jit.script as a decorator, for instance):
  1876. from torch._lobpcg import lobpcg as lobpcg
  1877. # These were previously defined in native_functions.yaml and appeared on the
  1878. # `torch` namespace, but we moved them to c10 dispatch to facilitate custom
  1879. # class usage. We add these lines here to preserve backward compatibility.
  1880. quantized_lstm = ops.aten.quantized_lstm
  1881. quantized_gru = ops.aten.quantized_gru
  1882. # Import experimental masked operations support. See
  1883. # [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
  1884. # information.
  1885. from torch import masked as masked
  1886. # Import removed ops with error message about removal
  1887. from torch._linalg_utils import ( # type: ignore[misc]
  1888. _symeig as symeig,
  1889. eig,
  1890. lstsq,
  1891. matrix_rank,
  1892. solve,
  1893. )
  1894. from torch.utils.dlpack import from_dlpack, to_dlpack
  1895. class _TorchCompileInductorWrapper:
  1896. compiler_name = "inductor"
  1897. def __init__(self, mode, options, dynamic):
  1898. from torch._inductor.compiler_bisector import CompilerBisector
  1899. self.config: dict[str, _Any] = {}
  1900. self.dynamic = dynamic
  1901. self.apply_mode(mode)
  1902. self.apply_options(options)
  1903. self.apply_options(CompilerBisector.get_config_change("inductor"))
  1904. cuda_version = None
  1905. if hasattr(torch, "version"):
  1906. from torch.torch_version import TorchVersion
  1907. cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0"))
  1908. if self.config.get("triton.cudagraphs", False) and (
  1909. (cuda_version and cuda_version < "12.6")
  1910. or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12()
  1911. ):
  1912. os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
  1913. # FIXME: CUDA Graph does not work well with CUPTI teardown.
  1914. # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
  1915. # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
  1916. # Workaround: turn off CUPTI teardown when using CUDA Graphs.
  1917. os.environ["TEARDOWN_CUPTI"] = "0"
  1918. def __eq__(self, other):
  1919. return (
  1920. isinstance(other, _TorchCompileInductorWrapper)
  1921. and self.config == other.config
  1922. and self.dynamic == other.dynamic
  1923. )
  1924. def apply_mode(self, mode: str | None):
  1925. if mode and mode != "default":
  1926. from torch._inductor import list_mode_options
  1927. self.apply_options(list_mode_options(mode, self.dynamic))
  1928. def apply_options(self, options: dict[str, _Any] | None):
  1929. if not options:
  1930. return
  1931. from torch._inductor import config
  1932. current_config: dict[str, _Any] = config.get_config_copy()
  1933. for key, val in options.items():
  1934. attr_name = key.replace("-", "_")
  1935. if attr_name not in current_config:
  1936. raise RuntimeError(
  1937. f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
  1938. )
  1939. attr_type = config.get_type(attr_name) # type: ignore[attr-defined]
  1940. # Subscriptable generic types don't support isinstance so skip the type
  1941. # check. There doesn't seem to be a good way of checking membership without
  1942. # 3rd party libraries.
  1943. if _get_origin(attr_type) is None:
  1944. if not isinstance(val, attr_type):
  1945. val_type_str = type(val).__name__
  1946. expected_type_str = type(current_config[attr_name]).__name__
  1947. raise RuntimeError(
  1948. f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}"
  1949. )
  1950. self.config[attr_name] = val
  1951. def __call__(self, model_, inputs_):
  1952. from torch._inductor.compile_fx import compile_fx
  1953. return compile_fx(model_, inputs_, config_patches=self.config)
  1954. def get_compiler_config(self):
  1955. from torch._inductor.compile_fx import get_patched_config_dict
  1956. return get_patched_config_dict(config_patches=self.config)
  1957. def reset(self):
  1958. from torch._inductor import config
  1959. if "triton.cudagraphs" in self.config or config.triton.cudagraphs:
  1960. if self.config.get("triton.cudagraphs", True):
  1961. from torch._inductor.cudagraph_trees import reset_cudagraph_trees
  1962. reset_cudagraph_trees()
  1963. class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
  1964. compiler_name = "aotinductor"
  1965. def __init__(self, mode, options, dynamic):
  1966. super().__init__(mode, options, dynamic)
  1967. self.apply_options({"cpp_wrapper": True})
  1968. self.apply_options({"aot_inductor.package": True})
  1969. def __call__(self, model_, inputs_):
  1970. from contextlib import nullcontext
  1971. from unittest import mock
  1972. from torch._guards import detect_fake_mode
  1973. from torch._inductor.virtualized import V
  1974. fake_mode = detect_fake_mode(inputs_)
  1975. ctx = (
  1976. mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
  1977. if fake_mode
  1978. else nullcontext()
  1979. )
  1980. with (
  1981. V.set_aot_compilation(True),
  1982. ctx,
  1983. torch._inductor.config.patch("enable_autograd_for_aot", True),
  1984. ):
  1985. return super().__call__(model_, inputs_)
  1986. class _TorchCompileWrapper:
  1987. def __init__(self, backend, mode, options, dynamic):
  1988. from torch._dynamo.backends.registry import lookup_backend
  1989. if isinstance(backend, str):
  1990. self.compiler_name = backend
  1991. elif hasattr(backend, "__name__"):
  1992. self.compiler_name = backend.__name__
  1993. else:
  1994. self.compiler_name = str(backend)
  1995. self.dynamic = dynamic
  1996. self.compiler_fn = lookup_backend(backend)
  1997. self.kwargs = {}
  1998. # only pass the args if they non-empty
  1999. if mode and mode != "default":
  2000. self.kwargs["mode"] = mode
  2001. if options:
  2002. self.kwargs["options"] = options
  2003. def __eq__(self, other):
  2004. return (
  2005. isinstance(other, _TorchCompileWrapper)
  2006. and self.compiler_fn == other.compiler_fn
  2007. and self.kwargs == other.kwargs
  2008. and self.dynamic == other.dynamic
  2009. )
  2010. def __call__(self, model_, inputs_):
  2011. return self.compiler_fn(model_, inputs_, **self.kwargs)
  2012. def reset(self):
  2013. if hasattr(self.compiler_fn, "reset"):
  2014. self.compiler_fn.reset()
# Type variables used by the ``compile`` signatures below: ``_InputT`` stands for
# the wrapped callable's parameters and ``_RetT`` for its return type.
_InputT = _ParamSpec("_InputT")
_RetT = _TypeVar("_RetT")
# Typing overload for the direct-call form: a callable is passed as ``model`` and
# a callable with the same parameters and return type comes back.
@_overload
def compile(
    model: _Callable[_InputT, _RetT],
    *,
    fullgraph: builtins.bool = False,
    dynamic: builtins.bool | None = None,
    backend: str | _Callable = "inductor",
    mode: str | None = None,
    options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None,
    disable: builtins.bool = False,
) -> _Callable[_InputT, _RetT]: ...
# Typing overload for the decorator-factory form: ``model`` is omitted (``None``)
# and the result is a decorator that takes the callable to compile.
@_overload
def compile(
    model: None = None,
    *,
    fullgraph: builtins.bool = False,
    dynamic: builtins.bool | None = None,
    backend: str | _Callable = "inductor",
    mode: str | None = None,
    options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None,
    disable: builtins.bool = False,
) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
  2039. def compile(
  2040. model: _Callable[_InputT, _RetT] | None = None,
  2041. *,
  2042. fullgraph: builtins.bool = False,
  2043. dynamic: builtins.bool | None = None,
  2044. backend: str | _Callable = "inductor",
  2045. mode: str | None = None,
  2046. options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None,
  2047. disable: builtins.bool = False,
  2048. ) -> (
  2049. _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]
  2050. | _Callable[_InputT, _RetT]
  2051. ):
  2052. """
  2053. Optimizes given model/function using TorchDynamo and specified backend.
  2054. If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile`
  2055. to compile the module inplace without changing its structure.
  2056. Concretely, for every frame executed within the compiled region, we will attempt
  2057. to compile it and cache the compiled result on the code object for future
  2058. use. A single frame may be compiled multiple times if previous compiled
  2059. results are not applicable for subsequent calls (this is called a "guard
  2060. failure"), you can use TORCH_LOGS=guards to debug these situations.
  2061. Multiple compiled results can be associated with a frame up to
  2062. ``torch._dynamo.config.recompile_limit``, which defaults to 8; at which
  2063. point we will fall back to eager. Note that compile caches are per
  2064. *code object*, not frame; if you dynamically create multiple copies of a
  2065. function, they will all share the same code cache.
  2066. Args:
  2067. model (Callable or None): Module/function to optimize
  2068. fullgraph (bool): If False (default), torch.compile attempts to discover compilable regions
  2069. in the function that it will optimize. If True, then we require that the entire function be
  2070. capturable into a single graph. If this is not possible (that is, if there are graph breaks),
  2071. then this will raise an error. This also opts into unbacked semantics, notably it will turn on
  2072. capture_scalar_outputs and capture_dynamic_output_shape_ops on by default.
  2073. dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt
  2074. to generate a kernel that is as dynamic as possible to avoid recompilations when
  2075. sizes change. This may not always work as some operations/optimizations will
  2076. force specialization; use TORCH_LOGS=dynamic to debug overspecialization.
  2077. When this is False, we will NEVER generate dynamic kernels, we will always specialize.
  2078. By default (None), we automatically detect if dynamism has occurred and compile a more
  2079. dynamic kernel upon recompile.
  2080. backend (str or Callable): backend to be used
  2081. - "inductor" is the default backend, which is a good balance between performance and overhead
  2082. - Non experimental in-tree backends can be seen with `torch._dynamo.list_backends()`
  2083. - Experimental or debug in-tree backends can be seen with `torch._dynamo.list_backends(None)`
  2084. - To register an out-of-tree custom backend:
  2085. https://docs.pytorch.org/docs/main/user_guide/torch_compiler/torch.compiler_custom_backends.html#registering-custom-backends
  2086. mode (str): Can be either "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs"
  2087. - "default" is the default mode, which is a good balance between performance and overhead
  2088. - "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs,
  2089. useful for small batches. Reduction of overhead can come at the cost of more memory
  2090. usage, as we will cache the workspace memory required for the invocation so that we
  2091. do not have to reallocate it on subsequent runs. Reduction of overhead is not guaranteed
  2092. to work; today, we only reduce overhead for CUDA only graphs which do not mutate inputs.
  2093. There are other circumstances where CUDA graphs are not applicable; use TORCH_LOGS=perf_hints
  2094. to debug.
  2095. - "max-autotune" is a mode that leverages Triton or template based matrix multiplications
  2096. on supported devices and Triton based convolutions on GPU.
  2097. It enables CUDA graphs by default on GPU.
  2098. - "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs
  2099. - To see the exact configs that each mode sets you can call `torch._inductor.list_mode_options()`
  2100. options (dict): A dictionary of options to pass to the backend. Some notable ones to try out are
  2101. - `epilogue_fusion` which fuses pointwise ops into templates. Requires `max_autotune` to also be set
  2102. - `max_autotune` which will profile to pick the best matmul configuration
  2103. - `fallback_random` which is useful when debugging accuracy issues
  2104. - `shape_padding` which pads matrix shapes to better align loads on GPUs especially for tensor cores
  2105. - `triton.cudagraphs` which will reduce the overhead of python with CUDA graphs
  2106. - `trace.enabled` which is the most useful debugging flag to turn on
  2107. - `trace.graph_diagram` which will show you a picture of your graph after fusion
  2108. - `guard_filter_fn` that controls which dynamo guards are saved with compilations.
  2109. This is an unsafe feature and there is no backward compatibility guarantee provided
  2110. for dynamo guards as data types.
  2111. For stable helper functions to use, see the documentations in `torch.compiler`, for example:
  2112. - `torch.compiler.skip_guard_on_inbuilt_nn_modules_unsafe`
  2113. - `torch.compiler.skip_guard_on_all_nn_modules_unsafe`
  2114. - `torch.compiler.keep_tensor_guards_unsafe`
  2115. - For inductor you can see the full list of configs that it supports by calling `torch._inductor.list_options()`
  2116. disable (bool): Turn torch.compile() into a no-op for testing
  2117. Example::
  2118. @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
  2119. def foo(x):
  2120. return torch.sin(x) + torch.cos(x)
  2121. """
  2122. import sysconfig
  2123. _C._log_api_usage_once("torch.compile")
  2124. if sys.version_info >= (3, 15):
  2125. raise RuntimeError("torch.compile is not supported on Python 3.15+")
  2126. elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1 and sys.version_info < (
  2127. 3,
  2128. 13,
  2129. 3,
  2130. ):
  2131. raise RuntimeError(
  2132. "torch.compile is not supported on Python < 3.13.3 built with GIL disabled. "
  2133. "Please use Python 3.13.3+."
  2134. )
  2135. # Decorator mode
  2136. if model is None:
  2137. def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
  2138. if model is None:
  2139. raise RuntimeError("Model can't be None")
  2140. return compile( # pyrefly: ignore # no-matching-overload
  2141. model,
  2142. fullgraph=fullgraph,
  2143. dynamic=dynamic,
  2144. backend=backend,
  2145. mode=mode,
  2146. options=options,
  2147. disable=disable,
  2148. )
  2149. return fn
  2150. if mode is not None and options is not None:
  2151. raise RuntimeError(
  2152. "Either mode or options can be specified, but both can't be specified at the same time."
  2153. )
  2154. if mode is None and options is None:
  2155. mode = "default"
  2156. from torch._inductor.compiler_bisector import CompilerBisector
  2157. if bisect_backend := CompilerBisector.get_backend():
  2158. import torch._inductor.config as inductor_config
  2159. # don't override the backend for use cases like vllm
  2160. # which leverages their custom backend.
  2161. if not (
  2162. inductor_config.test_configs.bisect_keep_custom_backend_for_inductor
  2163. and bisect_backend == "inductor"
  2164. and not isinstance(backend, str)
  2165. ):
  2166. backend = bisect_backend
  2167. guard_filter_fn = None
  2168. use_aoti = False
  2169. if options and isinstance(options, dict):
  2170. guard_filter_fn = options.pop("guard_filter_fn", None)
  2171. use_aoti = options.pop("use_aoti", False)
  2172. if torch.compiler.is_exporting():
  2173. from torch._higher_order_ops.utils import _in_hop_compile
  2174. if not _in_hop_compile():
  2175. warnings.warn(
  2176. "torch.compile is ignored when called inside torch.export region",
  2177. stacklevel=2,
  2178. )
  2179. # torch.compile is a no-op when inside torch.export region
  2180. return model
  2181. if backend == "inductor":
  2182. if use_aoti:
  2183. backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
  2184. else:
  2185. backend = _TorchCompileInductorWrapper(mode, options, dynamic)
  2186. else:
  2187. backend = _TorchCompileWrapper(backend, mode, options, dynamic)
  2188. return torch._dynamo.optimize(
  2189. backend=backend,
  2190. nopython=fullgraph,
  2191. dynamic=dynamic,
  2192. disable=disable,
  2193. guard_filter_fn=guard_filter_fn,
  2194. )(model) # type: ignore[return-value]
  2195. def _register_device_module(device_type, module):
  2196. r"""Register an external runtime module of the specific :attr:`device_type`
  2197. supported by torch.
  2198. After the :attr:`module` is registered correctly, the user can refer
  2199. the external runtime module as part of torch with attribute torch.xxx.
  2200. """
  2201. # Make sure the device_type represent a supported device type for torch.
  2202. device_type = torch.device(device_type).type
  2203. m = sys.modules[__name__]
  2204. if hasattr(m, device_type):
  2205. raise RuntimeError(
  2206. f"The runtime module of '{device_type}' has already "
  2207. f"been registered with '{getattr(m, device_type)}'"
  2208. )
  2209. setattr(m, device_type, module)
  2210. torch_module_name = ".".join([__name__, device_type])
  2211. sys.modules[torch_module_name] = module
  2212. from torch import (
  2213. export as export,
  2214. func as func,
  2215. library as library,
  2216. return_types as return_types,
  2217. )
  2218. from torch._higher_order_ops import cond as cond, while_loop as while_loop
  2219. from torch.func import vmap as vmap
  2220. if not TYPE_CHECKING:
  2221. # register python metas for distributed ops
  2222. # Only import if distributed is available (USE_DISTRIBUTED=1)
  2223. if hasattr(torch._C, "_c10d_init"):
  2224. import torch.distributed._meta_registrations as coll_meta_registrations
  2225. del coll_meta_registrations
  2226. from torch import _meta_registrations
  2227. # Enable CUDA Sanitizer
  2228. if "TORCH_CUDA_SANITIZER" in os.environ:
  2229. import torch.cuda._sanitizer as csan
  2230. csan.enable_cuda_sanitizer()
  2231. # Populate magic methods on SymInt and SymFloat
  2232. import torch.fx.experimental.sym_node
  2233. from torch import fx as fx
  2234. # Register MPS specific decomps
  2235. torch.backends.mps._init()
  2236. from torch import compiler as compiler
  2237. class _TritonLibrary:
  2238. lib = torch.library.Library("triton", "DEF")
  2239. ops_table: dict[tuple[str, str], _Callable] = {}
  2240. @classmethod
  2241. def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
  2242. if (op_key, dispatch_key) not in cls.ops_table:
  2243. cls.lib.define(full_schema)
  2244. cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
  2245. cls.ops_table[(op_key, dispatch_key)] = op_impl
  2246. return cls.ops_table[(op_key, dispatch_key)]
# Deprecated attributes
# Maps each deprecated attribute name to the zero-argument callable that the
# module-level ``__getattr__`` calls, after emitting a warning, to produce its value.
_deprecated_attrs = {
    "has_mps": torch.backends.mps.is_built,
    "has_cuda": torch.backends.cuda.is_built,
    "has_cudnn": torch.backends.cudnn.is_available,
    "has_mkldnn": torch.backends.mkldnn.is_available,
}
  2254. if TYPE_CHECKING:
  2255. # Import the following modules during type checking to enable code intelligence features,
  2256. # such as auto-completion in tools like pylance, even when these modules are not explicitly
  2257. # imported in user code.
  2258. from torch import (
  2259. _dynamo as _dynamo,
  2260. _inductor as _inductor,
  2261. _subclasses as _subclasses,
  2262. onnx as onnx,
  2263. )
  2264. else:
  2265. _lazy_modules = {
  2266. "_dynamo",
  2267. "_inductor",
  2268. "_export",
  2269. # ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
  2270. "onnx",
  2271. }
  2272. def __getattr__(name):
  2273. # Deprecated attrs
  2274. replacement = _deprecated_attrs.get(name)
  2275. if replacement is not None:
  2276. import warnings
  2277. warnings.warn(
  2278. f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'",
  2279. stacklevel=2,
  2280. )
  2281. return replacement()
  2282. # Lazy modules
  2283. if name in _lazy_modules:
  2284. return importlib.import_module(f".{name}", __name__)
  2285. raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
  2286. @functools.cache
  2287. def get_device_module(device: torch.device | str | None = None):
  2288. """
  2289. Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
  2290. If no device is given, return the module for the current accelerator or CPU if none is present.
  2291. """
  2292. if isinstance(device, torch.device):
  2293. device_module_name = device.type
  2294. elif isinstance(device, str):
  2295. device_module_name = torch.device(device).type
  2296. elif device is None:
  2297. # Using default accelerator type. If no accelerator is available, it automatically returns CPU device.
  2298. device_module_name = torch._C._get_accelerator().type
  2299. else:
  2300. raise RuntimeError(
  2301. f"Invalid value of device '{device}', expect torch.device, str, or None"
  2302. )
  2303. device_module = getattr(torch, device_module_name, None)
  2304. if device_module is None:
  2305. raise RuntimeError(
  2306. f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'."
  2307. )
  2308. return device_module
  2309. def _constrain_as_size(
  2310. symbol,
  2311. min: builtins.int | None = None,
  2312. max: builtins.int | None = None,
  2313. ):
  2314. """
  2315. This indicates that a given int is size-like, and can be used in any context where a size is expected.
  2316. You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist()
  2317. which then need to be used as tensor constructors. Providing these assertions to PyTorch can help resolve
  2318. GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts.
  2319. This function has unusual semantics in some circumstances in framework
  2320. code, we will treat this int as >= 2 (when we do a size-oblivious guard).
  2321. This makes it easier to use the unbacked int in size contexts,
  2322. as we will often attempt to guard on a size being zero/one
  2323. (e.g., when computing the contiguity of a tensor, or testing if
  2324. broadcasting can occur), which will not work on unbacked SymInts.
  2325. However, if we conservatively assume that the size is not zero/one, we will
  2326. end up with a graph that will still work even if the size is zero/one.
  2327. For more details, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit
  2328. ```
  2329. """
  2330. torch.sym_constrain_range_for_size(symbol, min=min, max=max)
  2331. from torch import _logging
  2332. _logging._init_logs()
  2333. def _import_device_backends():
  2334. """
  2335. Leverage the Python plugin mechanism to load out-of-the-tree device extensions.
  2336. See this RFC: https://github.com/pytorch/pytorch/issues/122468
  2337. """
  2338. from importlib.metadata import entry_points
  2339. group_name = "torch.backends"
  2340. backend_extensions = entry_points(group=group_name)
  2341. for backend_extension in backend_extensions:
  2342. try:
  2343. # Load the extension
  2344. entrypoint = backend_extension.load()
  2345. # Call the entrypoint
  2346. entrypoint()
  2347. except Exception as err:
  2348. raise RuntimeError(
  2349. f"Failed to load the backend extension: {backend_extension.name}. "
  2350. f"You can disable extension auto-loading with TORCH_DEVICE_BACKEND_AUTOLOAD=0."
  2351. ) from err
  2352. def _is_device_backend_autoload_enabled() -> builtins.bool:
  2353. """
  2354. Whether autoloading out-of-the-tree device extensions is enabled.
  2355. The switch depends on the value of the environment variable
  2356. `TORCH_DEVICE_BACKEND_AUTOLOAD`.
  2357. Returns:
  2358. bool: Whether to enable autoloading the extensions. Enabled by default.
  2359. Examples:
  2360. >>> torch._is_device_backend_autoload_enabled()
  2361. True
  2362. """
  2363. # enabled by default
  2364. return os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1") == "1"
  2365. def _as_tensor_fullprec(t):
  2366. """
  2367. Like torch.as_tensor, but when given Python data types it will keep
  2368. them in full precision. Used for calling convention for Dynamo.
  2369. Python scalars (float, int) are always created on CPU to avoid being
  2370. affected by DeviceContext.
  2371. """
  2372. ty = type(t)
  2373. if ty is builtins.float:
  2374. return torch.as_tensor(t, dtype=torch.float64, device="cpu")
  2375. elif ty is builtins.int:
  2376. return torch.as_tensor(t, dtype=torch.int64, device="cpu")
  2377. else:
  2378. return torch.as_tensor(t)
  2379. # `_import_device_backends` should be kept at the end to ensure
  2380. # all the other functions in this module that may be accessed by
  2381. # an autoloaded backend are defined
  2382. if _is_device_backend_autoload_enabled():
  2383. _import_device_backends()