| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266226722682269227022712272227322742275227622772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001 |
- """
- The torch package contains data structures for multi-dimensional
- tensors and defines mathematical operations over these tensors.
- Additionally, it provides many utilities for efficient serialization of
- Tensors and arbitrary types, and other useful utilities.
- It has a CUDA counterpart, that enables you to run your tensor computations
- on an NVIDIA GPU with compute capability >= 3.0.
- """
- # mypy: allow-untyped-defs
- import builtins
- import ctypes
- import functools
- import glob
- import importlib
- import inspect
- import math
- import os
- import platform
- import sys
- import textwrap
- import threading
- import warnings
- from collections.abc import Callable as _Callable
- from typing import (
- Any as _Any,
- get_origin as _get_origin,
- overload as _overload,
- TYPE_CHECKING,
- TypeVar as _TypeVar,
- )
- from typing_extensions import (
- deprecated as _deprecated,
- ParamSpec as _ParamSpec,
- TypeIs as _TypeIs,
- )
- # As a bunch of torch.packages internally still have this check
- # we need to keep this. @todo: Remove tests that rely on this check as
- # they are likely stale.
- def _running_with_deploy() -> builtins.bool:
- return False
- from torch._utils import (
- _functionalize_sync as _sync,
- _import_dotted_name,
- classproperty,
- )
- from torch._utils_internal import (
- get_file_path,
- prepare_multiprocessing_environment,
- profiler_allow_cudagraph_cupti_lazy_reinit_cuda12,
- USE_GLOBAL_DEPS,
- USE_RTLD_GLOBAL_WITH_LIBTORCH,
- )
- from torch.torch_version import __version__ as __version__
- if TYPE_CHECKING:
- from torch.types import Device, IntLikeType
- __all__ = [
- "BoolStorage",
- "BoolTensor",
- "ByteStorage",
- "ByteTensor",
- "CharStorage",
- "CharTensor",
- "DoubleStorage",
- "DoubleTensor",
- "FloatStorage",
- "FloatTensor",
- "GradScaler",
- "IntStorage",
- "IntTensor",
- "LongStorage",
- "LongTensor",
- "ShortStorage",
- "ShortTensor",
- "SymBool",
- "SymFloat",
- "SymInt",
- "Tensor",
- "TypedStorage",
- "UntypedStorage",
- "are_deterministic_algorithms_enabled",
- "autocast",
- "chunk",
- "compile",
- "cond",
- "enable_grad",
- "export",
- "get_default_device",
- "get_deterministic_debug_mode",
- "get_device_module",
- "get_float32_matmul_precision",
- "get_rng_state",
- "inference_mode",
- "initial_seed",
- "is_deterministic_algorithms_warn_only_enabled",
- "is_storage",
- "is_tensor",
- "is_warn_always_enabled",
- "load",
- "lobpcg",
- "manual_seed",
- "matmul",
- "no_grad",
- "rand",
- "randn",
- "save",
- "seed",
- "set_default_device",
- "set_default_tensor_type",
- "set_deterministic_debug_mode",
- "set_float32_matmul_precision",
- "set_printoptions",
- "set_rng_state",
- "set_warn_always",
- "split",
- "stack",
- "sym_float",
- "sym_fresh_size",
- "sym_int",
- "sym_ite",
- "sym_max",
- "sym_min",
- "sym_not",
- "sym_sum",
- "typename",
- "unravel_index",
- "use_deterministic_algorithms",
- "vmap",
- ]
- # Please keep this list sorted
- if __all__ != sorted(__all__):
- raise AssertionError("__all__ must be kept sorted")
- ################################################################################
- # Load the extension module
- ################################################################################
- # If PyTorch was built against the ROCm runtime wheels, then there will be
- # a _rocm_init module and it will define an initialize() function which can
- # prepare ROCm for use. See general documentation on ROCm runtime wheels:
- # https://github.com/ROCm/TheRock/blob/main/docs/packaging/python_packaging.md
- # Since this module is only ever added to the wheel if built for such a
- # deployment, it is always safe to attempt.
- try:
- from . import _rocm_init # type: ignore[attr-defined]
- except ImportError:
- pass
- else:
- _rocm_init.initialize()
- del _rocm_init
- if sys.platform == "win32":
- def _load_dll_libraries() -> None:
- import sysconfig
- from torch.version import cuda as cuda_version
- pfiles_path = os.getenv("ProgramFiles", r"C:\Program Files")
- py_dll_path = os.path.join(sys.exec_prefix, "Library", "bin")
- th_dll_path = os.path.join(os.path.dirname(__file__), "lib")
- usebase_path = os.path.join(
- sysconfig.get_config_var("userbase"), "Library", "bin"
- )
- py_root_bin_path = os.path.join(sys.exec_prefix, "bin")
- # When users create a virtualenv that inherits the base environment,
- # we will need to add the corresponding library directory into
- # DLL search directories. Otherwise, it will rely on `PATH` which
- # is dependent on user settings.
- if sys.exec_prefix != sys.base_exec_prefix:
- base_py_dll_path = os.path.join(sys.base_exec_prefix, "Library", "bin")
- else:
- base_py_dll_path = ""
- dll_paths = [
- p
- for p in (
- th_dll_path,
- py_dll_path,
- base_py_dll_path,
- usebase_path,
- py_root_bin_path,
- )
- if os.path.exists(p)
- ]
- if not builtins.any(
- os.path.exists(os.path.join(p, "nvToolsExt64_1.dll")) for p in dll_paths
- ):
- nvtoolsext_dll_path = os.path.join(
- os.getenv(
- "NVTOOLSEXT_PATH",
- os.path.join(pfiles_path, "NVIDIA Corporation", "NvToolsExt"),
- ),
- "bin",
- "x64",
- )
- else:
- nvtoolsext_dll_path = ""
- if cuda_version and builtins.all(
- not glob.glob(os.path.join(p, "cudart64*.dll")) for p in dll_paths
- ):
- cuda_version_1 = cuda_version.replace(".", "_")
- cuda_path_var = "CUDA_PATH_V" + cuda_version_1
- default_path = os.path.join(
- pfiles_path, "NVIDIA GPU Computing Toolkit", "CUDA", f"v{cuda_version}"
- )
- cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), "bin")
- else:
- cuda_path = ""
- dll_paths.extend(
- p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p)
- )
- kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
- with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
- prev_error_mode = kernel32.SetErrorMode(0x0001)
- kernel32.LoadLibraryW.restype = ctypes.c_void_p
- if with_load_library_flags:
- kernel32.LoadLibraryExW.restype = ctypes.c_void_p
- for dll_path in dll_paths:
- os.add_dll_directory(dll_path)
- try:
- ctypes.CDLL("vcruntime140.dll")
- ctypes.CDLL("msvcp140.dll")
- if platform.machine() != "ARM64":
- ctypes.CDLL("vcruntime140_1.dll")
- except OSError:
- print(
- textwrap.dedent(
- """
- Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.
- It can be downloaded at https://aka.ms/vs/17/release/vc_redist.x64.exe
- """
- ).strip()
- )
- dlls = glob.glob(os.path.join(th_dll_path, "*.dll"))
- path_patched = False
- for dll in dlls:
- is_loaded = False
- if with_load_library_flags:
- res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
- last_error = ctypes.get_last_error()
- if res is None and last_error != 126:
- err = ctypes.WinError(last_error)
- err.strerror += (
- f' Error loading "{dll}" or one of its dependencies.'
- )
- raise err
- elif res is not None:
- is_loaded = True
- if not is_loaded:
- if not path_patched:
- os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]])
- path_patched = True
- res = kernel32.LoadLibraryW(dll)
- if res is None:
- err = ctypes.WinError(ctypes.get_last_error())
- err.strerror += (
- f' Error loading "{dll}" or one of its dependencies.'
- )
- raise err
- kernel32.SetErrorMode(prev_error_mode)
- _load_dll_libraries()
- del _load_dll_libraries
- def _get_cuda_dep_paths(path: str, lib_folder: str, lib_name: str) -> list[str]:
- # Libraries can either be in
- # path/nvidia/lib_folder/lib or
- # path/nvidia/cuXX/lib (since CUDA 13.0) or
- # path/lib_folder/lib
- from torch.version import cuda as cuda_version
- nvidia_lib_paths = glob.glob(
- os.path.join(path, "nvidia", lib_folder, "lib", lib_name)
- )
- if cuda_version is not None:
- maj_cuda_version = cuda_version.split(".")[0]
- nvidia_lib_paths += glob.glob(
- os.path.join(path, "nvidia", f"cu{maj_cuda_version}", "lib", lib_name)
- )
- lib_paths = glob.glob(os.path.join(path, lib_folder, "lib", lib_name))
- return nvidia_lib_paths + lib_paths
- def _preload_cuda_lib(lib_folder: str, lib_name: str, required: bool = True) -> None: # type: ignore[valid-type]
- """Preloads cuda library if it could not be found otherwise."""
- # Should only be called on Linux if default path resolution have failed
- if platform.system() != "Linux":
- raise AssertionError(f"Should only be called on Linux, got {platform.system()}")
- lib_path = None
- for path in sys.path:
- candidate_lib_paths = _get_cuda_dep_paths(path, lib_folder, lib_name)
- if candidate_lib_paths:
- lib_path = candidate_lib_paths[0]
- break
- if not lib_path and required:
- raise ValueError(f"{lib_name} not found in the system path {sys.path}")
- if lib_path:
- ctypes.CDLL(lib_path)
- def _preload_cuda_deps(err: OSError | None = None) -> None:
- cuda_libs: list[tuple[str, str]] = [
- # NOTE: Order matters! We must preload libcublasLt BEFORE libcublas to prevent
- # libcublas from loading a mismatched system-wide libcublasLt via its RUNPATH.
- # Without this, if a different CUDA Toolkit version exists in the system PATH,
- # libcublas may load the wrong libcublasLt, causing symbol errors or runtime failures.
- ("cublas", "libcublasLt.so.*[0-9]"),
- ("cublas", "libcublas.so.*[0-9]"),
- ("cudnn", "libcudnn.so.*[0-9]"),
- ("cuda_nvrtc", "libnvrtc.so.*[0-9]"),
- ("cuda_nvrtc", "libnvrtc-builtins.so.*[0-9]"),
- ("cuda_runtime", "libcudart.so.*[0-9]"),
- ("cuda_cupti", "libcupti.so.*[0-9]"),
- ("cufft", "libcufft.so.*[0-9]"),
- ("curand", "libcurand.so.*[0-9]"),
- ("nvjitlink", "libnvJitLink.so.*[0-9]"),
- ("cusparse", "libcusparse.so.*[0-9]"),
- ("cusparselt", "libcusparseLt.so.*[0-9]"),
- ("cusolver", "libcusolver.so.*[0-9]"),
- ("nccl", "libnccl.so.*[0-9]"),
- ("nvshmem", "libnvshmem_host.so.*[0-9]"),
- ("cufile", "libcufile.so.*[0-9]"),
- ]
- # If error is passed, re-raise it if it's not about one of the abovementioned
- # libraries
- if err is not None and not [
- lib for _, lib in cuda_libs if lib.split(".", 1)[0] in err.args[0]
- ]:
- raise err
- # Otherwise, try to preload dependencies from site-packages
- for lib_folder, lib_name in cuda_libs:
- _preload_cuda_lib(lib_folder, lib_name)
- # libnvToolsExt is Optional Dependency
- _preload_cuda_lib("nvtx", "libnvToolsExt.so.*[0-9]", required=False)
- # See Note [Global dependencies]
- def _load_global_deps() -> None:
- if platform.system() == "Windows":
- return
- # Determine the file extension based on the platform
- lib_ext = ".dylib" if platform.system() == "Darwin" else ".so"
- lib_name = f"libtorch_global_deps{lib_ext}"
- here = os.path.abspath(__file__)
- global_deps_lib_path = os.path.join(os.path.dirname(here), "lib", lib_name)
- try:
- ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
- # Workaround slim-wheel CUDA dependency bugs in cusparse and cudnn by preloading nvjitlink
- # and nvrtc. In CUDA-12.4+ cusparse depends on nvjitlink, but does not have rpath when
- # shipped as wheel, which results in OS picking wrong/older version of nvjitlink library
- # if `LD_LIBRARY_PATH` is defined, see https://github.com/pytorch/pytorch/issues/138460
- # Similar issue exist in cudnn that dynamically loads nvrtc, unaware of its relative path.
- # See https://github.com/pytorch/pytorch/issues/145580
- try:
- with open("/proc/self/maps") as f:
- _maps = f.read()
- # libtorch_global_deps.so always depends in cudart, check if its installed and loaded
- if "libcudart.so" not in _maps:
- return
- # If all above-mentioned conditions are met, preload CUDA dependencies
- _preload_cuda_deps()
- except Exception:
- pass
- except OSError as err:
- # Can happen for wheel with cuda libs as PYPI deps
- # As PyTorch is not purelib, but nvidia-*-cu12 is
- _preload_cuda_deps(err)
- ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
- if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and (
- platform.system() != "Windows"
- ):
- # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a
- # few circumstances:
- #
- # 1. You're in a build environment (e.g., fbcode) where
- # libtorch_global_deps is not available, but you still need
- # to get mkl to link in with RTLD_GLOBAL or it will just
- # not work.
- #
- # 2. You're trying to run PyTorch under UBSAN and you need
- # to ensure that only one copy of libtorch is loaded, so
- # vptr checks work properly
- #
- # If you're using this setting, you must verify that all the libraries
- # you load consistently use the same libstdc++, or you may have
- # mysterious segfaults.
- #
- old_flags = sys.getdlopenflags()
- sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY)
- from torch._C import * # noqa: F403
- sys.setdlopenflags(old_flags)
- del old_flags
- else:
- # Easy way. You want this most of the time, because it will prevent
- # C++ symbols from libtorch clobbering C++ symbols from other
- # libraries, leading to mysterious segfaults.
- #
- # If building in an environment where libtorch_global_deps isn't available
- # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will
- # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False
- #
- # See Note [Global dependencies]
- if USE_GLOBAL_DEPS:
- _load_global_deps()
- from torch._C import * # noqa: F403
- class SymInt:
- """
- Like an int (including magic methods), but redirects all operations on the
- wrapped node. This is used in particular to symbolically record operations
- in the symbolic shape workflow.
- """
- def __init__(self, node):
- # This field MUST be named node; C++ binding code assumes that this
- # class has a field named node that stores SymNode
- self.node = node
- def __bool__(self):
- return builtins.bool(self != 0)
- def __int__(self):
- return self.node.int_()
- def __index__(self):
- return self.node.int_()
- # Magic methods installed by torch.fx.experimental.sym_node
- def __round__(self, ndigits=None):
- return self
- def __truediv__(self, other):
- if isinstance(other, (builtins.float, SymFloat)):
- return sym_float(self).__float_truediv__(other)
- if not isinstance(other, (builtins.int, SymInt)):
- return NotImplemented
- return self.__int_truediv__(other)
- def __rtruediv__(self, other):
- if isinstance(other, (builtins.float, SymFloat)):
- return sym_float(self).__rfloat_truediv__(other)
- if not isinstance(other, (builtins.int, SymInt)):
- return NotImplemented
- return self.__rint_truediv__(other)
- def __floordiv__(self, other):
- if isinstance(other, (builtins.float, SymFloat)):
- return sym_float(math.floor(sym_float(self) / other))
- if not isinstance(other, (builtins.int, SymInt)):
- return NotImplemented
- return self.__int_floordiv__(other)
- def __rfloordiv__(self, other):
- if isinstance(other, (builtins.float, SymFloat)):
- return sym_float(math.floor(other / sym_float(self)))
- if not isinstance(other, (builtins.int, SymInt)):
- return NotImplemented
- return self.__rint_floordiv__(other)
- # nb: complex is impossible to handle correctly lol, with
- # negative base and integral float need to diverge semantics and
- # just always return complex. Neener neener pretend this problem
- # doesn't exist
- def __pow__(self, other):
- if isinstance(other, (builtins.float, SymFloat)):
- return sym_float(self).__pow__(other)
- if not isinstance(other, (builtins.int, SymInt)):
- return NotImplemented
- # Guards! This guard is necessary because we need to know it to
- # determine the output type of this operation
- if other >= 0:
- return self.__pow_by_natural__(other)
- else:
- # Mercifully, when the exponent is negative, Python just promotes
- # to doubles and does a float pow:
- #
- # if (Py_SIZE(b) < 0 && c == NULL) {
- # /* if exponent is negative and there's no modulus:
- # return a float. This works because we know
- # that this calls float_pow() which converts its
- # arguments to double. */
- # Py_DECREF(a);
- # Py_DECREF(b);
- # return PyFloat_Type.tp_as_number->nb_power(v, w, x);
- # }
- return sym_float(self).__pow__(sym_float(other))
- def __rpow__(self, other):
- if isinstance(other, (builtins.float, SymFloat)):
- return sym_float(self).__rpow__(other)
- if not isinstance(other, (builtins.int, SymInt)):
- return NotImplemented
- if self >= 0: # self is exponent
- return self.__rpow_by_natural__(other)
- else:
- return sym_float(self).__rpow__(sym_float(other))
- def __eq__(self, other: object) -> builtins.bool:
- raise TypeError("type stub not overridden")
- def __lt__(self, other) -> builtins.bool:
- raise TypeError("type stub not overridden")
- def __gt__(self, other) -> builtins.bool:
- raise TypeError("type stub not overridden")
- def __le__(self, other) -> builtins.bool:
- raise TypeError("type stub not overridden")
- def __ge__(self, other) -> builtins.bool:
- raise TypeError("type stub not overridden")
- def __add__(self, other) -> "SymInt":
- raise TypeError("type stub not overridden")
- def __radd__(self, other) -> "SymInt":
- raise TypeError("type stub not overridden")
- def __rmul__(self, other) -> "SymInt":
- raise TypeError("type stub not overridden")
- def __mod__(self, other: "IntLikeType") -> "SymInt":
- raise TypeError("type stub not overridden")
- def __mul__(self, other) -> "SymInt":
- raise TypeError("type stub not overridden")
- def __pow_by_natural__(self, other) -> "SymInt":
- raise TypeError("type stub not overridden")
- def __rpow_by_natural__(self, other) -> "SymInt":
- raise TypeError("type stub not overridden")
- def __int_truediv__(self, other) -> "SymFloat":
- raise TypeError("type stub not overridden")
- def __rint_truediv__(self, other) -> "SymFloat":
- raise TypeError("type stub not overridden")
- def __int_floordiv__(self, other) -> "SymFloat":
- raise TypeError("type stub not overridden")
- def __rint_floordiv__(self, other) -> "SymFloat":
- raise TypeError("type stub not overridden")
- def __sym_max__(self, other):
- raise TypeError("type stub not overridden")
- def __sym_min__(self, other):
- raise TypeError("type stub not overridden")
- def __sym_float__(self):
- raise TypeError("type stub not overridden")
- def __neg__(self):
- raise TypeError("type stub not overridden")
- def __sub__(self, other: "IntLikeType") -> "SymInt":
- raise TypeError("type stub not overridden")
- def __rsub__(self, other: "IntLikeType") -> "SymInt":
- raise TypeError("type stub not overridden")
- def __and__(self, other) -> "SymInt":
- raise TypeError("type stub not overridden")
- def __or__(self, other) -> "SymInt":
- raise TypeError("type stub not overridden")
- def __repr__(self):
- return self.node._graph_repr()
- def _sympy_(self):
- return self.node.expr
- def __hash__(self) -> builtins.int:
- if self.node.is_nested_int():
- return hash(self.node.nested_int())
- else:
- # We could support constant SymInts as well, but not doing it for now
- raise TypeError("unhashable type: non-nested SymInt")
- # TODO: Force specialization
- # This can't be done because the TypeError here is load bearing
- # for einops
- # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
- # return hash(builtins.int(self))
- def as_integer_ratio(self) -> tuple["SymInt", builtins.int]:
- """Represent this int as an exact integer ratio"""
- return self, 1
- def bit_length(self) -> builtins.int:
- # TODO: A more relaxed guard is possible here, where you guard to
- # allow all integer quantities which would result in the same bit
- # length. We can also just make a dedicated Sympy function for
- # computing this quantity and represent it symbolically.
- return builtins.int(self).bit_length()
- def conjugate(self) -> "SymInt":
- return self
- class SymFloat:
- """
- Like a float (including magic methods), but redirects all operations on the
- wrapped node. This is used in particular to symbolically record operations
- in the symbolic shape workflow.
- """
- def __init__(self, node):
- # This field MUST be named node; C++ binding code assumes that this
- # class has a field named node that stores SymNode
- self.node = node
- def __truediv__(self, other):
- if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
- return NotImplemented
- return self.__float_truediv__(sym_float(other))
- def __rtruediv__(self, other):
- if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
- return NotImplemented
- return self.__rfloat_truediv__(sym_float(other))
- def __floordiv__(self, other):
- if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
- return NotImplemented
- return sym_float(math.floor(self / sym_float(other)))
- def __rfloordiv__(self, other):
- if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
- return NotImplemented
- return sym_float(math.floor(sym_float(other) / self))
- def __bool__(self):
- return self.node.bool_()
- def __float__(self):
- return self.node.guard_float("", 0)
- def __int__(self):
- return self.__trunc__().__int__()
- # Symbolic power does NOT work with negative base, this is to avoid
- # potential complex outputs
- def __pow__(self, other):
- if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
- return NotImplemented
- torch._check(self >= 0)
- return self.__float_pow__(other)
- def __rpow__(self, other):
- if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
- return NotImplemented
- torch._check(other >= 0)
- return self.__rfloat_pow__(other)
- # Magic methods installed by torch.fx.experimental.sym_node
- def __eq__(self, other: object) -> builtins.bool:
- raise TypeError("type stub not overridden")
- def __lt__(self, other) -> builtins.bool:
- raise TypeError("type stub not overridden")
- def __gt__(self, other) -> builtins.bool:
- raise TypeError("type stub not overridden")
- def __le__(self, other) -> builtins.bool:
- raise TypeError("type stub not overridden")
- def __ge__(self, other) -> builtins.bool:
- raise TypeError("type stub not overridden")
- def __float_pow__(self, other) -> "SymFloat":
- raise TypeError("type stub not overridden")
- def __rfloat_pow__(self, other) -> "SymFloat":
- raise TypeError("type stub not overridden")
- def __float_truediv__(self, other) -> "SymFloat":
- raise TypeError("type stub not overridden")
- def __rfloat_truediv__(self, other) -> "SymFloat":
- raise TypeError("type stub not overridden")
- def __trunc__(self):
- raise TypeError("type stub not overridden")
- def __sym_max__(self, other):
- raise TypeError("type stub not overridden")
- def __sym_min__(self, other):
- raise TypeError("type stub not overridden")
- def __sym_int__(self):
- raise TypeError("type stub not overridden")
- def is_integer(self):
- """Return True if the float is an integer."""
- raise TypeError("type stub not overridden")
- def as_integer_ratio(self) -> tuple[builtins.int, builtins.int]:
- """Represent this float as an exact integer ratio"""
- return builtins.float(self).as_integer_ratio()
- def __repr__(self):
- return self.node._graph_repr()
- def _sympy_(self):
- return self.node.expr
- def __hash__(self):
- return hash(builtins.float(self))
- def conjugate(self) -> "SymFloat":
- """Returns the complex conjugate of the float."""
- return self
- def hex(self) -> str:
- """Returns the hexadecimal representation of the float."""
- return self.node.guard_float("", 0).hex()
- class SymBool:
- """
- Like a bool (including magic methods), but redirects all operations on the
- wrapped node. This is used in particular to symbolically record operations
- in the symbolic shape workflow.
- Unlike regular bools, regular boolean operators will force extra guards instead
- of symbolically evaluate. Use the bitwise operators instead to handle this.
- """
- def __init__(self, node):
- # This field MUST be named node; C++ binding code assumes that this
- # class has a field named node that stores SymNode
- self.node = node
- def __bool__(self):
- return self.node.bool_()
- def __int__(self):
- return builtins.int(self.node.bool_())
- # Magic methods installed by torch.fx.experimental.sym_node
- def __and__(self, other) -> "SymBool":
- raise TypeError("type stub not overridden")
- def __or__(self, other) -> "SymBool":
- raise TypeError("type stub not overridden")
- # We very carefully define __sym_not__, and not a number of other
- # plausible alternatives:
- #
- # - We do not override __not__ because this is not a real magic
- # method; you cannot override the meaning of the not builtin in
- # Python. We use the name 'sym_not' to clarify that in user code you
- # cannot use the builtin not or operator.not_ or operator.__not__ and
- # hit this magic method; you must use our custom sym_not operator.
- #
- # - We do not override the __invert__ method because SymBool is
- # meant to be usable in situations where bool is expected. However,
- # bitwise negation ~a does the wrong thing with booleans (because
- # bool is a subclass of int, so ~1 = -2 which is not falseish.)
- # This would be a giant footgun, so we get around it by defining
- # our own operator. Note that bitwise and/or do the right thing,
- # so we reuse the conventional operators there for readability.
- #
- def __sym_not__(self) -> "SymBool":
- raise TypeError("type stub not overridden")
- def __sym_ite__(self, then_val, else_val):
- raise TypeError("type stub not overridden")
- def __eq__(self, other) -> builtins.bool:
- raise TypeError("type stub not overridden")
- def __repr__(self):
- return self.node._graph_repr()
- def _sympy_(self):
- return self.node.expr
- def __hash__(self):
- if self.node.is_constant():
- return hash(self.node.bool_())
- else:
- # Force specialization
- return hash(builtins.bool(self))
- def __sym_float__(self):
- """
- Provides a SymFloat representation (0.0 or 1.0) for this SymBool.
- Called by torch.sym_float() when casting SymBool to float.
- """
- from torch.fx.experimental.sym_node import wrap_node
- return wrap_node(self.node.sym_float())
- def sym_not(a):
- r"""SymInt-aware utility for logical negation.
- Args:
- a (SymBool or bool): Object to negate
- """
- import sympy
- if overrides.has_torch_function_unary(a):
- return overrides.handle_torch_function(sym_not, (a,), a)
- if hasattr(a, "__sym_not__"):
- return a.__sym_not__()
- if isinstance(a, sympy.Basic):
- return ~a # type: ignore[operator]
- return not a
- def sym_float(a):
- r"""SymInt-aware utility for float casting.
- Args:
- a (SymInt, SymFloat, or object): Object to cast
- """
- if overrides.has_torch_function_unary(a):
- return overrides.handle_torch_function(sym_float, (a,), a)
- if isinstance(a, SymFloat):
- return a
- elif hasattr(a, "__sym_float__"):
- return a.__sym_float__()
- return builtins.float(a) # type: ignore[operator]
- def sym_int(a):
- r"""SymInt-aware utility for int casting.
- Args:
- a (SymInt, SymFloat, or object): Object to cast
- """
- if overrides.has_torch_function_unary(a):
- return overrides.handle_torch_function(sym_int, (a,), a)
- if isinstance(a, SymInt):
- return a
- elif isinstance(a, SymFloat):
- return math.trunc(a)
- return builtins.int(a) # type: ignore[operator]
- def sym_max(a, b):
- """
- SymInt-aware utility for max which avoids branching on a < b.
- Unlike builtins.max(), this only works for int/float, and it always
- promotes to float if any argument is float (unlike builtins.max, which
- will faithfully preserve the type of the input argument).
- """
- if overrides.has_torch_function((a, b)):
- return overrides.handle_torch_function(sym_max, (a, b), a, b)
- if isinstance(a, (SymInt, SymFloat)):
- return a.__sym_max__(b)
- elif isinstance(b, (SymInt, SymFloat)):
- # Due to promotion semantics, this is operator is commutative:
- # max(1, 1.0) === max(1.0, 1) === 1.0
- return b.__sym_max__(a)
- # TODO: Probably can make bool work too, just lazy
- all_types, float_types = __all_and_float_types()
- if not isinstance(a, all_types):
- raise AssertionError(f"expected {all_types}, got {type(a)}")
- if not isinstance(b, all_types):
- raise AssertionError(f"expected {all_types}, got {type(b)}")
- if isinstance(a, float_types) or isinstance(b, float_types):
- return builtins.float(builtins.max(a, b)) # type: ignore[call-overload]
- else:
- return builtins.max(a, b) # type: ignore[call-overload]
- def __all_and_float_types() -> tuple[tuple[type, ...], tuple[type, ...]]:
- try:
- import numpy as np
- all_types: tuple[type, ...] = (
- np.integer,
- np.floating,
- builtins.int,
- builtins.float,
- )
- float_types: tuple[type, ...] = (np.floating, builtins.float)
- except ModuleNotFoundError:
- all_types = (builtins.int, builtins.float)
- float_types = (builtins.float,)
- return all_types, float_types
- def sym_min(a, b):
- """SymInt-aware utility for min()."""
- if overrides.has_torch_function((a, b)):
- return overrides.handle_torch_function(sym_min, (a, b), a, b)
- if isinstance(a, (SymInt, SymFloat)):
- return a.__sym_min__(b)
- elif isinstance(b, (SymInt, SymFloat)):
- return b.__sym_min__(a)
- all_types, float_types = __all_and_float_types()
- if not isinstance(a, all_types):
- raise AssertionError(f"expected {all_types}, got {type(a)}")
- if not isinstance(b, all_types):
- raise AssertionError(f"expected {all_types}, got {type(b)}")
- if isinstance(a, float_types) or isinstance(b, float_types):
- return builtins.float(builtins.min(a, b)) # type: ignore[call-overload]
- else:
- return builtins.min(a, b) # type: ignore[call-overload]
- def sym_sum(args):
- """
- N-ary add which is faster to compute for long lists than iterated binary
- addition. Only does something special for integers.
- """
- if overrides.has_torch_function(args):
- return overrides.handle_torch_function(sym_sum, args, args)
- found = None
- for a in args:
- if not isinstance(a, (SymInt, builtins.int)):
- return builtins.sum(args)
- if isinstance(a, SymInt):
- found = a.node
- if found is None:
- return builtins.sum(args)
- from torch.fx.experimental.sym_node import to_node, wrap_node
- return wrap_node(found.sym_sum(tuple(to_node(found, a) for a in args)))
- # Drop in replacement for math.sqrt, math.sin, math.cos etc
- def _get_sym_math_fn(name):
- def fn(a):
- if overrides.has_torch_function_unary(a):
- return overrides.handle_torch_function(fn, (a,), a)
- if isinstance(a, SymInt):
- a = torch.sym_float(a)
- if hasattr(a, f"__sym_{name}__"):
- return getattr(a, f"__sym_{name}__")()
- return getattr(math, name)(a)
- return fn
- __fn, __name, __sym_name = None, "", ""
- for __name in (
- "sqrt",
- "cos",
- "cosh",
- "sin",
- "sinh",
- "tan",
- "tanh",
- "asin",
- "acos",
- "atan",
- "log2",
- ):
- __sym_name = f"_sym_{__name}"
- __fn = _get_sym_math_fn(__name)
- __fn.__qualname__ = __fn.__name__ = __sym_name
- globals()[__sym_name] = __fn
- del __fn, __name, __sym_name, _get_sym_math_fn
- # Adding temporary shortcut
- sym_sqrt = globals()["_sym_sqrt"]
- __all__.append("sym_sqrt")
- def sym_ite(b, t, f):
- """SymInt-aware utility for ternary operator (``t if b else f``.)"""
- if overrides.has_torch_function((b, t, f)):
- return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f)
- if not isinstance(b, (SymBool, builtins.bool)):
- raise AssertionError(f"expected SymBool or bool, got {type(b)}")
- if type(t) is not type(f):
- raise AssertionError(f"type mismatch: {type(t)} vs {type(f)}")
- if isinstance(b, SymBool):
- return b.__sym_ite__(t, f)
- return t if b else f
- # Create a fresh unbacked int, from an (possibly unbacked int) expression.
- def sym_fresh_size(expr):
- return torch.tensor(expr).item()
- # Check to see if we can load C extensions, and if not provide some guidance
- # on what the problem might be.
- try:
- # _initExtension is chosen (arbitrarily) as a sentinel.
- from torch._C import _initExtension
- except ImportError:
- import torch._C as _C_for_compiled_check
- if _C_for_compiled_check.__file__ is None:
- raise ImportError(
- textwrap.dedent(
- """
- Failed to load PyTorch C extensions:
- It appears that PyTorch has loaded the `torch/_C` folder
- of the PyTorch repository rather than the C extensions which
- are expected in the `torch._C` namespace. This can occur when
- using the `install` workflow. e.g.
- $ python -m pip install --no-build-isolation -v . && python -c "import torch"
- This error can generally be solved using the `develop` workflow
- $ python -m pip install --no-build-isolation -v -e . && python -c "import torch" # This should succeed
- or by running Python from a different directory.
- """
- ).strip()
- ) from None
- raise # If __file__ is not None the cause is unknown, so just re-raise.
- # The torch._C submodule is already loaded via `from torch._C import *` above
- # Make an explicit reference to the _C submodule to appease linters
- from torch import _C as _C
- __name, __obj = "", None
- for __name in dir(_C):
- if __name[0] != "_" and not __name.endswith("Base"):
- __all__.append(__name)
- __obj = getattr(_C, __name)
- if callable(__obj) or inspect.isclass(__obj):
- if __obj.__module__ != __name__: # "torch"
- # TODO: fix their module from C++ side
- if __name not in {
- "DisableTorchFunctionSubclass",
- "DisableTorchFunction",
- "Generator",
- }:
- __obj.__module__ = __name__ # "torch"
- elif __name == "TensorBase":
- # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
- delattr(sys.modules[__name__], __name)
- del __name, __obj
- if not TYPE_CHECKING:
- # issue 38137 and python issue 43367. Submodules of a C extension are
- # non-standard, and attributes of those submodules cannot be pickled since
- # pickle expect to be able to import them as "from _C.sub import attr"
- # which fails with "_C is not a package
- def _import_extension_to_sys_modules(module, memo=None):
- if memo is None:
- memo = set()
- if module in memo:
- return
- memo.add(module)
- module_name = module.__name__
- for name in dir(module):
- member = getattr(module, name)
- member_name = getattr(member, "__name__", "")
- if inspect.ismodule(member) and member_name.startswith(module_name):
- sys.modules.setdefault(member_name, member)
- # Recurse for submodules (e.g., `_C._dynamo.eval_frame`)
- _import_extension_to_sys_modules(member, memo)
- _import_extension_to_sys_modules(_C)
- del _import_extension_to_sys_modules
- ################################################################################
- # Define basic utilities
- ################################################################################
- def typename(obj: _Any, /) -> str:
- """
- String representation of the type of an object.
- This function returns a fully qualified string representation of an object's type.
- Args:
- obj (object): The object whose type to represent
- Returns:
- str: the type of the object `o`
- Example:
- >>> x = torch.tensor([1, 2, 3])
- >>> torch.typename(x)
- 'torch.LongTensor'
- >>> torch.typename(torch.nn.Parameter)
- 'torch.nn.parameter.Parameter'
- """
- if isinstance(obj, torch.Tensor):
- return obj.type()
- module = getattr(obj, "__module__", "") or ""
- qualname = ""
- if hasattr(obj, "__qualname__"):
- qualname = obj.__qualname__
- elif hasattr(obj, "__name__"):
- qualname = obj.__name__
- else:
- module = obj.__class__.__module__ or ""
- qualname = obj.__class__.__qualname__
- if module in {"", "builtins"}:
- return qualname
- return f"{module}.{qualname}"
- def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]:
- r"""Returns True if `obj` is a PyTorch tensor.
- Args:
- obj (object): Object to test
- Example::
- >>> x = torch.tensor([1, 2, 3])
- >>> torch.is_tensor(x)
- True
- """
- return isinstance(obj, torch.Tensor)
- def is_storage(obj: _Any, /) -> builtins.bool:
- r"""Returns True if `obj` is a PyTorch storage object.
- Args:
- obj (Object): Object to test
- Example::
- >>> import torch
- >>> # UntypedStorage (recommended)
- >>> tensor = torch.tensor([1, 2, 3])
- >>> storage = tensor.untyped_storage()
- >>> torch.is_storage(storage)
- True
- >>>
- >>> # TypedStorage (legacy)
- >>> typed_storage = torch.TypedStorage(5, dtype=torch.float32)
- >>> torch.is_storage(typed_storage)
- True
- >>>
- >>> # regular tensor (should return False)
- >>> torch.is_storage(tensor)
- False
- >>>
- >>> # non-storage object
- >>> torch.is_storage([1, 2, 3])
- False
- """
- return type(obj) in _storage_classes
- _GLOBAL_DEVICE_CONTEXT = threading.local()
- def get_default_device() -> "torch.device":
- r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
- global _GLOBAL_DEVICE_CONTEXT
- from torch.overrides import _get_current_function_mode_stack
- from torch.utils._device import DeviceContext
- def _get_device_with_index(device):
- if device.index is not None:
- return device
- else:
- # TODO: Call like get_device_index() method corresponding to
- # each device type
- return torch.tensor([]).device
- # Get device from any active DeviceContext.
- device_mode = next(
- filter(
- lambda mode: isinstance(mode, DeviceContext),
- reversed(_get_current_function_mode_stack()),
- ),
- None,
- )
- if device_mode:
- device = device_mode.device
- return _get_device_with_index(device)
- device_context = getattr(_GLOBAL_DEVICE_CONTEXT, "device_context", None)
- if device_context is not None:
- return _get_device_with_index(device_context.device)
- return torch.device("cpu")
- def set_default_device(device: "Device") -> None:
- """Sets the default ``torch.Tensor`` to be allocated on ``device``. This
- does not affect factory function calls which are called with an explicit
- ``device`` argument. Factory calls will be performed as if they
- were passed ``device`` as an argument.
- To only temporarily change the default device instead of setting it
- globally, use ``with torch.device(device):`` instead.
- The default device is initially ``cpu``. If you set the default tensor
- device to another device (e.g., ``cuda``) without a device index, tensors
- will be allocated on whatever the current device for the device type,
- even after :func:`torch.cuda.set_device` is called.
- .. warning::
- This function imposes a slight performance cost on every Python
- call to the torch API (not just factory functions). If this
- is causing problems for you, please comment on
- https://github.com/pytorch/pytorch/issues/92701
- .. note::
- This doesn't affect functions that create tensors that share the same memory as the input, like:
- :func:`torch.from_numpy` and :func:`torch.frombuffer`
- Args:
- device (device or string): the device to set as default
- Example::
- >>> # xdoctest: +SKIP("requires cuda, changes global state")
- >>> torch.get_default_device()
- device(type='cpu')
- >>> torch.set_default_device('cuda') # current device is 0
- >>> torch.get_default_device()
- device(type='cuda', index=0)
- >>> torch.set_default_device('cuda')
- >>> torch.cuda.set_device('cuda:1') # current device is 1
- >>> torch.get_default_device()
- device(type='cuda', index=1)
- >>> torch.set_default_device('cuda:1')
- >>> torch.get_default_device()
- device(type='cuda', index=1)
- """
- global _GLOBAL_DEVICE_CONTEXT
- if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
- device_context = _GLOBAL_DEVICE_CONTEXT.device_context
- if device_context is not None:
- device_context.__exit__(None, None, None)
- if device is None:
- device_context = None
- else:
- from torch.utils._device import DeviceContext
- device_context = DeviceContext(device)
- device_context.__enter__()
- _GLOBAL_DEVICE_CONTEXT.device_context = device_context
- def set_default_tensor_type(t: type["torch.Tensor"] | str, /) -> None:
- r"""
- .. warning::
- This function is deprecated as of PyTorch 2.1, please use :func:`torch.set_default_dtype()` and
- :func:`torch.set_default_device()` as alternatives.
- Sets the default ``torch.Tensor`` type to floating point tensor type
- ``t``. This type will also be used as default floating point type for
- type inference in :func:`torch.tensor`.
- The default floating point tensor type is initially ``torch.FloatTensor``.
- Args:
- t (type or string): the floating point tensor type or its name
- Example::
- >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
- >>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32
- torch.float32
- >>> torch.set_default_tensor_type(torch.DoubleTensor)
- >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
- torch.float64
- """
- if isinstance(t, str):
- t = _import_dotted_name(t)
- _C._set_default_tensor_type(t)
- def set_default_dtype(d: "torch.dtype", /) -> None:
- r"""
- Sets the default floating point dtype to :attr:`d`. Supports floating point dtype
- as inputs. Other dtypes will cause torch to raise an exception.
- When PyTorch is initialized its default floating point dtype is torch.float32,
- and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like
- type inference. The default floating point dtype is used to:
- 1. Implicitly determine the default complex dtype. When the default floating type is float16,
- the default complex dtype is complex32. For float32, the default complex dtype is complex64.
- For float64, it is complex128. For bfloat16, an exception will be raised because
- there is no corresponding complex type for bfloat16.
- 2. Infer the dtype for tensors constructed using Python floats or complex Python
- numbers. See examples below.
- 3. Determine the result of type promotion between bool and integer tensors and
- Python floats and complex Python numbers.
- Args:
- d (:class:`torch.dtype`): the floating point dtype to make the default.
- Example:
- >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
- >>> # initial default for floating point is torch.float32
- >>> # Python floats are interpreted as float32
- >>> torch.tensor([1.2, 3]).dtype
- torch.float32
- >>> # initial default for floating point is torch.complex64
- >>> # Complex Python numbers are interpreted as complex64
- >>> torch.tensor([1.2, 3j]).dtype
- torch.complex64
- >>> torch.set_default_dtype(torch.float64)
- >>> # Python floats are now interpreted as float64
- >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
- torch.float64
- >>> # Complex Python numbers are now interpreted as complex128
- >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor
- torch.complex128
- >>> torch.set_default_dtype(torch.float16)
- >>> # Python floats are now interpreted as float16
- >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
- torch.float16
- >>> # Complex Python numbers are now interpreted as complex128
- >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor
- torch.complex32
- """
- _C._set_default_dtype(d)
- def use_deterministic_algorithms(
- mode: builtins.bool,
- *,
- warn_only: builtins.bool = False,
- ) -> None:
- r"""Sets whether PyTorch operations must use "deterministic"
- algorithms. That is, algorithms which, given the same input, and when
- run on the same software and hardware, always produce the same output.
- When enabled, operations will use deterministic algorithms when available,
- and if only nondeterministic algorithms are available they will throw a
- :class:`RuntimeError` when called.
- .. note:: This setting alone is not always enough to make an application
- reproducible. Refer to :ref:`reproducibility` for more information.
- .. note:: :func:`torch.set_deterministic_debug_mode` offers an alternative
- interface for this feature.
- The following normally-nondeterministic operations will act
- deterministically when ``mode=True``:
- * :class:`torch.nn.Conv1d` when called on CUDA tensor
- * :class:`torch.nn.Conv2d` when called on CUDA tensor
- * :class:`torch.nn.Conv3d` when called on CUDA tensor
- * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor
- * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor
- * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor
- * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor
- * :func:`torch.bmm` when called on sparse-dense CUDA tensors
- * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
- and the index is a list of tensors
- * :func:`torch.Tensor.index_put` with ``accumulate=False``
- * :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU
- tensor
- * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
- tensor
- * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
- * :func:`torch.gather` when called on a CUDA tensor that requires grad
- * :func:`torch.index_add` when called on CUDA tensor
- * :func:`torch.index_select` when attempting to differentiate a CUDA tensor
- * :func:`torch.repeat_interleave` when attempting to differentiate a CUDA tensor
- * :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
- * :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor
- * :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor
- The following normally-nondeterministic operations will throw a
- :class:`RuntimeError` when ``mode=True``:
- * :class:`torch.nn.AvgPool3d` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.AdaptiveAvgPool2d` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.AdaptiveAvgPool3d` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.MaxPool3d` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.AdaptiveMaxPool2d` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.FractionalMaxPool2d` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.FractionalMaxPool3d` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.MaxUnpool1d`
- * :class:`torch.nn.MaxUnpool2d`
- * :class:`torch.nn.MaxUnpool3d`
- * :func:`torch.nn.functional.interpolate` when attempting to differentiate a CUDA tensor
- and one of the following modes is used:
- - ``linear``
- - ``bilinear``
- - ``bicubic``
- - ``trilinear``
- * :class:`torch.nn.ReflectionPad1d` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.NLLLoss` when called on a CUDA tensor
- * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor
- * :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when
- ``mode='max'``
- * :func:`torch.Tensor.put_` when ``accumulate=False``
- * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
- * :func:`torch.histc` when called on a CUDA tensor
- * :func:`torch.bincount` when called on a CUDA tensor and ``weights``
- tensor is given
- * :func:`torch.median` with indices output when called on a CUDA tensor
- * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor
- * :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex
- * :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor
- * :func:`torch.Tensor.resize_` when called with a quantized tensor
- In addition, several operations fill uninitialized memory when this setting
- is turned on and when
- :attr:`torch.utils.deterministic.fill_uninitialized_memory` is turned on.
- See the documentation for that attribute for more information.
- Note that deterministic operations tend to have worse performance than
- nondeterministic operations.
- When this setting is turned on, the Inductor deterministic mode is also tuned on
- automatically. In deterministic mode, Inductor would avoid doing on device benchmarking
- that affect numerics. This includes:
- - don't pad matmul input shapes. Without enabling deterministic mode, Inductor would do
- benchmarking to check if padding matmul shape is beneficial.
- - don't autotune templates. Inductor has templates for kernels like matmul/conv/attention.
- Without enabling deterministic mode, Inductor would do autotuning to
- pick the best configs for those templates and adopt it if it's faster
- than the kernel in eager mode. In deterministic mode, we pick the eager kernel.
- - don't autotune triton configs for reduction. Reduction numerics are
- very sensitive to triton configs. In deterministic mode, Inductor
- will use some heuristics to pick the most promising configs rather
- than do autotuning.
- - Skip autotuning for reduction in coordinate descent tuning.
- - Don't benchmarking for the computation/communication reordering pass
- - Disable the feature that dynamically scale down RBLOCK triton config for higher
- occupancy.
- .. note::
- This flag does not detect or prevent nondeterministic behavior caused
- by calling an inplace operation on a tensor with an internal memory
- overlap or by giving such a tensor as the :attr:`out` argument for an
- operation. In these cases, multiple writes of different data may target
- a single memory location, and the order of writes is not guaranteed.
- Args:
- mode (:class:`bool`): If True, makes potentially nondeterministic
- operations switch to a deterministic algorithm or throw a runtime
- error. If False, allows nondeterministic operations.
- Keyword args:
- warn_only (:class:`bool`, optional): If True, operations that do not
- have a deterministic implementation will throw a warning instead of
- an error. Default: ``False``
- Example::
- >>> # xdoctest: +SKIP
- >>> torch.use_deterministic_algorithms(True)
- # Backward mode nondeterministic error
- >>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward()
- ...
- RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation...
- """
- import torch._inductor.config as inductor_config
- inductor_config.deterministic = mode
- _C._set_deterministic_algorithms(mode, warn_only=warn_only)
- def are_deterministic_algorithms_enabled() -> builtins.bool:
- r"""Returns True if the global deterministic flag is turned on. Refer to
- :func:`torch.use_deterministic_algorithms` documentation for more details.
- """
- return _C._get_deterministic_algorithms()
- def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool:
- r"""Returns True if the global deterministic flag is set to warn only.
- Refer to :func:`torch.use_deterministic_algorithms` documentation for more
- details.
- """
- return _C._get_deterministic_algorithms_warn_only()
- def set_deterministic_debug_mode(debug_mode: builtins.int | str) -> None:
- r"""Sets the debug mode for deterministic operations.
- .. note:: This is an alternative interface for
- :func:`torch.use_deterministic_algorithms`. Refer to that function's
- documentation for details about affected operations.
- Args:
- debug_mode(str or int): If "default" or 0, don't error or warn on
- nondeterministic operations. If "warn" or 1, warn on
- nondeterministic operations. If "error" or 2, error on
- nondeterministic operations.
- """
- # NOTE: builtins.int is used here because int in this scope resolves
- # to torch.int
- if not isinstance(debug_mode, (builtins.int, str)):
- raise TypeError(f"debug_mode must be str or int, but got {type(debug_mode)}")
- if isinstance(debug_mode, str):
- if debug_mode == "default":
- debug_mode = 0
- elif debug_mode == "warn":
- debug_mode = 1
- elif debug_mode == "error":
- debug_mode = 2
- else:
- raise RuntimeError(
- "invalid value of debug_mode, expected one of `default`, "
- f"`warn`, `error`, but got {debug_mode}"
- )
- if debug_mode == 0:
- _C._set_deterministic_algorithms(False)
- elif debug_mode == 1:
- _C._set_deterministic_algorithms(True, warn_only=True)
- elif debug_mode == 2:
- _C._set_deterministic_algorithms(True)
- else:
- raise RuntimeError(
- f"invalid value of debug_mode, expected 0, 1, or 2, but got {debug_mode}"
- )
- def get_deterministic_debug_mode() -> builtins.int:
- r"""Returns the current value of the debug mode for deterministic
- operations. Refer to :func:`torch.set_deterministic_debug_mode`
- documentation for more details.
- """
- if _C._get_deterministic_algorithms():
- if _C._get_deterministic_algorithms_warn_only():
- return 1
- else:
- return 2
- else:
- return 0
- def get_float32_matmul_precision() -> str:
- r"""Returns the current value of float32 matrix multiplication precision. Refer to
- :func:`torch.set_float32_matmul_precision` documentation for more details.
- """
- return _C._get_float32_matmul_precision()
- def set_float32_matmul_precision(precision: str) -> None:
- r"""Sets the internal precision of float32 matrix multiplications.
- Running float32 matrix multiplications in lower precision may significantly increase
- performance, and in some programs the loss of precision has a negligible impact.
- Supports three settings:
- * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa
- bits with 23 bits explicitly stored) for internal computations.
- * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10
- mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers
- (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication
- algorithms are available. Otherwise float32 matrix multiplications are computed
- as if the precision is "highest". See below for more information on the bfloat16
- approach.
- * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa
- bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm
- using that datatype internally is available. Otherwise float32
- matrix multiplications are computed as if the precision is "high".
- When using "high" precision, float32 multiplications may use a bfloat16-based algorithm
- that is more complicated than simply truncating to some smaller number mantissa bits
- (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored). Refer to [Henry2019]_ for a complete
- description of this algorithm. To briefly explain here, the first step is to realize
- that we can perfectly encode a single float32 number as the sum of three bfloat16
- numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the
- same number of exponent bits). This means that the product of two float32 numbers can
- be exactly given by the sum of nine products of bfloat16 numbers. We can then trade
- accuracy for speed by dropping some of these products. The "high" precision algorithm
- specifically keeps only the three most significant products, which conveniently excludes
- all of the products involving the last 8 mantissa bits of either input. This means that
- we can represent our inputs as the sum of two bfloat16 numbers rather than three.
- Because bfloat16 fused-multiply-add (FMA) instructions are typically >10x faster than
- float32 ones, it's faster to do three multiplications and 2 additions with bfloat16
- precision than it is to do a single multiplication with float32 precision.
- .. [Henry2019] http://arxiv.org/abs/1904.06376
- .. note::
- This does not change the output dtype of float32 matrix multiplications,
- it controls how the internal computation of the matrix multiplication is performed.
- .. note::
- This does not change the precision of convolution operations. Other flags,
- like `torch.backends.cudnn.allow_tf32`, may control the precision of convolution
- operations.
- .. note::
- This flag currently only affects one native device type: CUDA.
- If "high" or "medium" are set then the TensorFloat32 datatype will be used
- when computing float32 matrix multiplications, equivalent to setting
- `torch.backends.cuda.matmul.allow_tf32 = True`. When "highest" (the default)
- is set then the float32 datatype is used for internal computations, equivalent
- to setting `torch.backends.cuda.matmul.allow_tf32 = False`.
- Args:
- precision(str): can be set to "highest" (default), "high", or "medium" (see above).
- """
- _C._set_float32_matmul_precision(precision)
- def set_warn_always(b: builtins.bool, /) -> None:
- r"""When this flag is False (default) then some PyTorch warnings may only
- appear once per process. This helps avoid excessive warning information.
- Setting it to True causes these warnings to always appear, which may be
- helpful when debugging.
- Args:
- b (:class:`bool`): If True, force warnings to always be emitted
- If False, set to the default behaviour
- """
- _C._set_warnAlways(b)
- def is_warn_always_enabled() -> builtins.bool:
- r"""Returns True if the global warn_always flag is turned on. Refer to
- :func:`torch.set_warn_always` documentation for more details.
- """
- return _C._get_warnAlways()
- ################################################################################
- # Define error checking functions
- ################################################################################
- # These error checking functions must be kept consistent with their C++
- # equivalents. Their C++ equivalents are mentioned where applicable.
- def _check_with(
- error_type,
- cond: builtins.bool | SymBool,
- message: _Callable[[], str],
- ): # noqa: F811
- if not isinstance(cond, (builtins.bool, SymBool)):
- raise TypeError(f"cond must be a bool, but got {type(cond)}")
- from torch.fx.experimental.symbolic_shapes import expect_true
- if expect_true(cond):
- return
- # error_type must be a subclass of Exception and not subclass of Warning
- if not issubclass(error_type, Exception) or issubclass(error_type, Warning):
- raise AssertionError(
- f"error_type must be a subclass of Exception but not Warning, got {error_type}"
- )
- if message is None:
- message_evaluated = (
- "Expected cond to be True, but got False. (Could this error "
- "message be improved? If so, please report an enhancement request "
- "to PyTorch.)"
- )
- else:
- if not callable(message):
- raise TypeError("message must be a callable")
- message_evaluated = str(message())
- raise error_type(message_evaluated)
- def _check(cond, message=None): # noqa: F811
- r"""Throws error containing an optional message if the specified condition
- is False.
- Error type: ``RuntimeError``
- C++ equivalent: ``TORCH_CHECK``
- Args:
- cond (:class:`bool`): If False, throw error
- message (Callable, optional): Callable that returns either a string or
- an object that has a ``__str__()`` method to be used as the error
- message. Default: ``None``
- """
- _check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type]
- @_deprecated(
- "_check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. \
- Use _check(i >= 0) instead.",
- category=FutureWarning,
- )
- def _check_is_size(i, message=None, *, max=None):
- """Checks that a given integer is a valid size (i.e., is non-negative).
- You should use this over ``_check(i >= 0)`` because it can prevent
- ``GuardOnDataDependentSymNode`` exceptions by opting yourself into alternate
- semantics for ``guard_size_oblivious`` tests that treat values 0 and 1
- equivalently to all other values.
- When max is not None, this specifies an upper bound equivalent to
- ``_check(i <= max)``. This bound is also subject to alternate semantics:
- in ``guard_size_oblivious`` tests, we assume that a constant max bound is
- treated equivalently to all other values. Symbolic max bounds are not yet
- supported.
- NB: Do NOT use this in contexts where a -1 size would be valid (indicating
- to infer the size from context, or if you should wrap-around or truncate).
- Only use this if the only valid value is an honest to goodness size.
- """
- # This is responsible for the expect_true
- _check(i >= 0, message)
- from torch.fx.experimental.symbolic_shapes import _advise_is_size
- _advise_is_size(i)
- if max is not None:
- _check(i <= max, message)
- from torch.fx.experimental.symbolic_shapes import _advise_is_bounded
- _advise_is_bounded(i, max)
- def _check_index(cond, message=None): # noqa: F811
- r"""Throws error containing an optional message if the specified condition
- is False.
- Error type: ``IndexError``
- C++ equivalent: ``TORCH_CHECK_INDEX``
- Args:
- cond (:class:`bool`): If False, throw error
- message (Callable, optional): Callable that returns either a string or
- an object that has a ``__str__()`` method to be used as the error
- message. Default: ``None``
- """
- _check_with(IndexError, cond, message) # pyrefly: ignore [bad-argument-type]
- def _check_value(cond, message=None): # noqa: F811
- r"""Throws error containing an optional message if the specified condition
- is False.
- Error type: ``ValueError``
- C++ equivalent: ``TORCH_CHECK_VALUE``
- Args:
- cond (:class:`bool`): If False, throw error
- message (Callable, optional): Callable that returns either a string or
- an object that has a ``__str__()`` method to be used as the error
- message. Default: ``None``
- """
- _check_with(ValueError, cond, message) # pyrefly: ignore [bad-argument-type]
- def _check_type(cond, message=None): # noqa: F811
- r"""Throws error containing an optional message if the specified condition
- is False.
- Error type: ``TypeError``
- C++ equivalent: ``TORCH_CHECK_TYPE``
- Args:
- cond (:class:`bool`): If False, throw error
- message (Callable, optional): Callable that returns either a string or
- an object that has a ``__str__()`` method to be used as the error
- message. Default: ``None``
- """
- _check_with(TypeError, cond, message) # pyrefly: ignore [bad-argument-type]
- def _check_not_implemented(cond, message=None): # noqa: F811
- r"""Throws error containing an optional message if the specified condition
- is False.
- Error type: ``NotImplementedError``
- C++ equivalent: ``TORCH_CHECK_NOT_IMPLEMENTED``
- Args:
- cond (:class:`bool`): If False, throw error
- message (Callable, optional): Callable that returns either a string or
- an object that has a ``__str__()`` method to be used as the error
- message. Default: ``None``
- """
- _check_with(
- NotImplementedError,
- cond,
- # pyrefly: ignore [bad-argument-type]
- message,
- )
- def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
- if not is_tensor(cond):
- raise TypeError(f"cond must be a tensor, but got {type(cond)}")
- if not cond.dtype == torch.bool:
- raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}")
- _check_with(error_type, cond._is_all_true().item(), message) # type: ignore[arg-type]
- # C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
- def _check_tensor_all(cond, message=None): # noqa: F811
- r"""Throws error containing an optional message if the specified condition
- is False.
- Error type: ``RuntimeError``
- C++ equivalent: ``TORCH_CHECK_TENSOR_ALL``
- Args:
- cond (:class:`torch.Tensor`): Tensor of dtype ``torch.bool``. If any
- element is ``False``, throw error
- message (Callable, optional): Callable that returns either a string or
- an object that has a ``__str__()`` method to be used as the error
- message. Default: ``None``
- """
- _check_tensor_all_with(RuntimeError, cond, message)
- ################################################################################
- # Define numeric constants
- ################################################################################
- # For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and
- # NumPy consistency (https://numpy.org/devdocs/reference/constants.html)
- from math import e, inf, nan, pi
- newaxis: None = None
- __all__.extend(["e", "pi", "nan", "inf", "newaxis"])
- ################################################################################
- # Define Storage and Tensor classes
- ################################################################################
- from torch._tensor import Tensor # usort: skip
- # needs to be after torch.Tensor is defined to avoid circular dependencies
- from torch import storage as storage # usort: skip
- from torch.storage import (
- _LegacyStorage,
- _StorageBase,
- _warn_typed_storage_removal,
- TypedStorage,
- UntypedStorage,
- )
- # NOTE: New <type>Storage classes should never be added. When adding a new
- # dtype, use torch.storage.TypedStorage directly.
- class ByteStorage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.uint8
- class DoubleStorage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.double
- class FloatStorage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.float
- class HalfStorage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.half
- class LongStorage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.long
- class IntStorage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.int
- class ShortStorage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.short
- class CharStorage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.int8
- class BoolStorage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.bool
- class BFloat16Storage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.bfloat16
- class ComplexDoubleStorage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.cdouble
- class ComplexFloatStorage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.cfloat
- class QUInt8Storage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.quint8
- class QInt8Storage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.qint8
- class QInt32Storage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.qint32
- class QUInt4x2Storage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.quint4x2
- class QUInt2x4Storage(_LegacyStorage):
- @classproperty
- def dtype(self):
- _warn_typed_storage_removal(stacklevel=3)
- return self._dtype
- @classproperty
- def _dtype(self):
- return torch.quint2x4
- _storage_classes: set[type[TypedStorage | UntypedStorage]] = {
- UntypedStorage,
- DoubleStorage,
- FloatStorage,
- LongStorage,
- IntStorage,
- ShortStorage,
- CharStorage,
- ByteStorage,
- HalfStorage,
- BoolStorage,
- QUInt8Storage,
- QInt8Storage,
- QInt32Storage,
- BFloat16Storage,
- ComplexFloatStorage,
- ComplexDoubleStorage,
- QUInt4x2Storage,
- QUInt2x4Storage,
- TypedStorage,
- }
- # The _tensor_classes set is initialized by the call to initialize_python_bindings.
- _tensor_classes: set[type["torch.Tensor"]] = set()
- # If you edit these imports, please update torch/__init__.py.in as well
- from torch import amp as amp, random as random, serialization as serialization
- from torch._tensor_str import set_printoptions
- from torch.amp import autocast, GradScaler
- from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state
- from torch.serialization import load, save
- ################################################################################
- # Initialize extension
- ################################################################################
- # Shared memory manager needs to know the exact location of manager executable
- def _manager_path():
- if platform.system() == "Windows":
- return b""
- path = get_file_path("torch", "bin", "torch_shm_manager")
- prepare_multiprocessing_environment(get_file_path("torch"))
- if not os.path.exists(path):
- raise RuntimeError("Unable to find torch_shm_manager at " + path)
- return path.encode("utf-8")
- _C._initExtension(_manager_path())
- del _manager_path
- # Appease the type checker: it can't deal with direct setting of globals().
- # Note that we will see "too many" functions when reexporting this way; there
- # is not a good way to fix this problem. Perhaps, try to redesign VariableFunctions
- # so that this import is good enough
- if TYPE_CHECKING:
- # Some type signatures pulled in from _VariableFunctions here clash with
- # signatures already imported. For now these clashes are ignored; see
- # PR #43339 for details.
- from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403
- # Fixup segment_reduce visibility
- _segment_reduce = segment_reduce
- del segment_reduce # noqa: F821
- # Ops not to be exposed in `torch` namespace,
- # mostly helper ops.
- PRIVATE_OPS = ("unique_dim",)
- __name, __obj = "", None
- for __name in dir(_C._VariableFunctions):
- if __name.startswith("__") or __name in PRIVATE_OPS:
- continue
- __obj = getattr(_C._VariableFunctions, __name)
- __obj.__module__ = __name__ # "torch"
- # Hide some APIs that should not be public
- if __name == "segment_reduce":
- # TODO: Once the undocumented FC window is passed, remove the line below
- globals()[__name] = __obj
- __name = "_" + __name
- globals()[__name] = __obj
- if not __name.startswith("_"):
- __all__.append(__name)
- del __name, __obj
- ################################################################################
- # Add torch.dtype instances to the public API
- ################################################################################
- import torch
- __all__.extend(
- name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype)
- )
- ################################################################################
- # Import TorchDynamo's lazy APIs to avoid circular dependencies
- ################################################################################
- # needs to be before from torch.functional import * to avoid circular dependencies
- from torch._compile import _disable_dynamo # usort: skip
- ################################################################################
- # Import interface functions defined in Python
- ################################################################################
- # needs to be after the above ATen bindings so we can overwrite from Python side
- from torch import _VF as _VF, functional as functional # usort: skip
- from torch.functional import * # usort: skip # noqa: F403
- ################################################################################
- # Remove unnecessary members
- ################################################################################
- del _StorageBase
- del _LegacyStorage
- ################################################################################
- # Define _assert
- ################################################################################
- # needs to be before the submodule imports to avoid circular dependencies
- def _assert(condition, message):
- r"""A wrapper around Python's assert which is symbolically traceable."""
- if type(condition) is not torch.Tensor and overrides.has_torch_function(
- (condition,)
- ):
- return overrides.handle_torch_function(
- _assert, (condition,), condition, message
- )
- if not condition:
- raise AssertionError(message)
- ################################################################################
- # Import most common subpackages
- ################################################################################
- # Use the redundant form so that type checkers know that these are a part of
- # the public API. The "regular" import lines are there solely for the runtime
- # side effect of adding to the imported module's members for other users.
- # needs to be before import torch.nn as nn to avoid circular dependencies
- from torch.autograd import ( # usort: skip
- enable_grad as enable_grad,
- inference_mode as inference_mode,
- no_grad as no_grad,
- set_grad_enabled as set_grad_enabled,
- )
- from torch import (
- __config__ as __config__,
- __future__ as __future__,
- _awaits as _awaits,
- accelerator as accelerator,
- autograd as autograd,
- backends as backends,
- cpu as cpu,
- cuda as cuda,
- distributed as distributed,
- distributions as distributions,
- fft as fft,
- futures as futures,
- hub as hub,
- jit as jit,
- linalg as linalg,
- mps as mps,
- mtia as mtia,
- multiprocessing as multiprocessing,
- nested as nested,
- nn as nn,
- optim as optim,
- overrides as overrides,
- profiler as profiler,
- sparse as sparse,
- special as special,
- testing as testing,
- types as types,
- utils as utils,
- version as version,
- xpu as xpu,
- )
- from torch.signal import windows as windows
- # Quantized, sparse, AO, etc. should be last to get imported, as nothing
- # is expected to depend on them.
- from torch import ao as ao # usort: skip
- # nn.quant* depends on ao -- so should be after those.
- import torch.nn.intrinsic
- import torch.nn.qat
- import torch.nn.quantizable
- import torch.nn.quantized
- _C._init_names(list(_storage_classes))
- # attach docstrings to torch and tensor functions
- from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs
- del _torch_docs, _tensor_docs, _storage_docs, _size_docs
- def compiled_with_cxx11_abi() -> builtins.bool:
- r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
- return True
- from torch import _library as _library, _ops as _ops
- # Import the ops and classes "namespace"
- from torch._ops import ops as ops # usort: skip
- from torch._classes import classes as classes # usort: skip
- sys.modules.setdefault(f"{__name__}.ops", ops)
- sys.modules.setdefault(f"{__name__}.classes", classes)
- # quantization depends on torch.fx and torch.ops
- # Import quantization
- from torch import quantization as quantization # usort: skip
- # Import the quasi random sampler
- from torch import quasirandom as quasirandom # usort: skip
- # If you are seeing this, it means that this call site was not checked if
- # the memory format could be preserved, and it was switched to old default
- # behaviour of contiguous
- legacy_contiguous_format = contiguous_format # defined by _C._initExtension()
- # Register fork handler to initialize OpenMP in child processes (see gh-28389)
- from torch.multiprocessing._atfork import register_after_fork
- register_after_fork(torch.get_num_threads)
- del register_after_fork
- # Import tools that require fully imported torch (for applying
- # torch.jit.script as a decorator, for instance):
- from torch._lobpcg import lobpcg as lobpcg
- # These were previously defined in native_functions.yaml and appeared on the
- # `torch` namespace, but we moved them to c10 dispatch to facilitate custom
- # class usage. We add these lines here to preserve backward compatibility.
- quantized_lstm = ops.aten.quantized_lstm
- quantized_gru = ops.aten.quantized_gru
- # Import experimental masked operations support. See
- # [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
- # information.
- from torch import masked as masked
- # Import removed ops with error message about removal
- from torch._linalg_utils import ( # type: ignore[misc]
- _symeig as symeig,
- eig,
- lstsq,
- matrix_rank,
- solve,
- )
- from torch.utils.dlpack import from_dlpack, to_dlpack
- class _TorchCompileInductorWrapper:
- compiler_name = "inductor"
- def __init__(self, mode, options, dynamic):
- from torch._inductor.compiler_bisector import CompilerBisector
- self.config: dict[str, _Any] = {}
- self.dynamic = dynamic
- self.apply_mode(mode)
- self.apply_options(options)
- self.apply_options(CompilerBisector.get_config_change("inductor"))
- cuda_version = None
- if hasattr(torch, "version"):
- from torch.torch_version import TorchVersion
- cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0"))
- if self.config.get("triton.cudagraphs", False) and (
- (cuda_version and cuda_version < "12.6")
- or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12()
- ):
- os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
- # FIXME: CUDA Graph does not work well with CUPTI teardown.
- # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
- # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
- # Workaround: turn off CUPTI teardown when using CUDA Graphs.
- os.environ["TEARDOWN_CUPTI"] = "0"
- def __eq__(self, other):
- return (
- isinstance(other, _TorchCompileInductorWrapper)
- and self.config == other.config
- and self.dynamic == other.dynamic
- )
- def apply_mode(self, mode: str | None):
- if mode and mode != "default":
- from torch._inductor import list_mode_options
- self.apply_options(list_mode_options(mode, self.dynamic))
- def apply_options(self, options: dict[str, _Any] | None):
- if not options:
- return
- from torch._inductor import config
- current_config: dict[str, _Any] = config.get_config_copy()
- for key, val in options.items():
- attr_name = key.replace("-", "_")
- if attr_name not in current_config:
- raise RuntimeError(
- f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
- )
- attr_type = config.get_type(attr_name) # type: ignore[attr-defined]
- # Subscriptable generic types don't support isinstance so skip the type
- # check. There doesn't seem to be a good way of checking membership without
- # 3rd party libraries.
- if _get_origin(attr_type) is None:
- if not isinstance(val, attr_type):
- val_type_str = type(val).__name__
- expected_type_str = type(current_config[attr_name]).__name__
- raise RuntimeError(
- f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}"
- )
- self.config[attr_name] = val
- def __call__(self, model_, inputs_):
- from torch._inductor.compile_fx import compile_fx
- return compile_fx(model_, inputs_, config_patches=self.config)
- def get_compiler_config(self):
- from torch._inductor.compile_fx import get_patched_config_dict
- return get_patched_config_dict(config_patches=self.config)
- def reset(self):
- from torch._inductor import config
- if "triton.cudagraphs" in self.config or config.triton.cudagraphs:
- if self.config.get("triton.cudagraphs", True):
- from torch._inductor.cudagraph_trees import reset_cudagraph_trees
- reset_cudagraph_trees()
- class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
- compiler_name = "aotinductor"
- def __init__(self, mode, options, dynamic):
- super().__init__(mode, options, dynamic)
- self.apply_options({"cpp_wrapper": True})
- self.apply_options({"aot_inductor.package": True})
- def __call__(self, model_, inputs_):
- from contextlib import nullcontext
- from unittest import mock
- from torch._guards import detect_fake_mode
- from torch._inductor.virtualized import V
- fake_mode = detect_fake_mode(inputs_)
- ctx = (
- mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
- if fake_mode
- else nullcontext()
- )
- with (
- V.set_aot_compilation(True),
- ctx,
- torch._inductor.config.patch("enable_autograd_for_aot", True),
- ):
- return super().__call__(model_, inputs_)
- class _TorchCompileWrapper:
- def __init__(self, backend, mode, options, dynamic):
- from torch._dynamo.backends.registry import lookup_backend
- if isinstance(backend, str):
- self.compiler_name = backend
- elif hasattr(backend, "__name__"):
- self.compiler_name = backend.__name__
- else:
- self.compiler_name = str(backend)
- self.dynamic = dynamic
- self.compiler_fn = lookup_backend(backend)
- self.kwargs = {}
- # only pass the args if they non-empty
- if mode and mode != "default":
- self.kwargs["mode"] = mode
- if options:
- self.kwargs["options"] = options
- def __eq__(self, other):
- return (
- isinstance(other, _TorchCompileWrapper)
- and self.compiler_fn == other.compiler_fn
- and self.kwargs == other.kwargs
- and self.dynamic == other.dynamic
- )
- def __call__(self, model_, inputs_):
- return self.compiler_fn(model_, inputs_, **self.kwargs)
- def reset(self):
- if hasattr(self.compiler_fn, "reset"):
- self.compiler_fn.reset()
- _InputT = _ParamSpec("_InputT")
- _RetT = _TypeVar("_RetT")
- @_overload
- def compile(
- model: _Callable[_InputT, _RetT],
- *,
- fullgraph: builtins.bool = False,
- dynamic: builtins.bool | None = None,
- backend: str | _Callable = "inductor",
- mode: str | None = None,
- options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None,
- disable: builtins.bool = False,
- ) -> _Callable[_InputT, _RetT]: ...
- @_overload
- def compile(
- model: None = None,
- *,
- fullgraph: builtins.bool = False,
- dynamic: builtins.bool | None = None,
- backend: str | _Callable = "inductor",
- mode: str | None = None,
- options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None,
- disable: builtins.bool = False,
- ) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
- def compile(
- model: _Callable[_InputT, _RetT] | None = None,
- *,
- fullgraph: builtins.bool = False,
- dynamic: builtins.bool | None = None,
- backend: str | _Callable = "inductor",
- mode: str | None = None,
- options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None,
- disable: builtins.bool = False,
- ) -> (
- _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]
- | _Callable[_InputT, _RetT]
- ):
- """
- Optimizes given model/function using TorchDynamo and specified backend.
- If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile`
- to compile the module inplace without changing its structure.
- Concretely, for every frame executed within the compiled region, we will attempt
- to compile it and cache the compiled result on the code object for future
- use. A single frame may be compiled multiple times if previous compiled
- results are not applicable for subsequent calls (this is called a "guard
- failure"), you can use TORCH_LOGS=guards to debug these situations.
- Multiple compiled results can be associated with a frame up to
- ``torch._dynamo.config.recompile_limit``, which defaults to 8; at which
- point we will fall back to eager. Note that compile caches are per
- *code object*, not frame; if you dynamically create multiple copies of a
- function, they will all share the same code cache.
- Args:
- model (Callable or None): Module/function to optimize
- fullgraph (bool): If False (default), torch.compile attempts to discover compilable regions
- in the function that it will optimize. If True, then we require that the entire function be
- capturable into a single graph. If this is not possible (that is, if there are graph breaks),
- then this will raise an error. This also opts into unbacked semantics, notably it will turn on
- capture_scalar_outputs and capture_dynamic_output_shape_ops on by default.
- dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt
- to generate a kernel that is as dynamic as possible to avoid recompilations when
- sizes change. This may not always work as some operations/optimizations will
- force specialization; use TORCH_LOGS=dynamic to debug overspecialization.
- When this is False, we will NEVER generate dynamic kernels, we will always specialize.
- By default (None), we automatically detect if dynamism has occurred and compile a more
- dynamic kernel upon recompile.
- backend (str or Callable): backend to be used
- - "inductor" is the default backend, which is a good balance between performance and overhead
- - Non experimental in-tree backends can be seen with `torch._dynamo.list_backends()`
- - Experimental or debug in-tree backends can be seen with `torch._dynamo.list_backends(None)`
- - To register an out-of-tree custom backend:
- https://docs.pytorch.org/docs/main/user_guide/torch_compiler/torch.compiler_custom_backends.html#registering-custom-backends
- mode (str): Can be either "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs"
- - "default" is the default mode, which is a good balance between performance and overhead
- - "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs,
- useful for small batches. Reduction of overhead can come at the cost of more memory
- usage, as we will cache the workspace memory required for the invocation so that we
- do not have to reallocate it on subsequent runs. Reduction of overhead is not guaranteed
- to work; today, we only reduce overhead for CUDA only graphs which do not mutate inputs.
- There are other circumstances where CUDA graphs are not applicable; use TORCH_LOGS=perf_hints
- to debug.
- - "max-autotune" is a mode that leverages Triton or template based matrix multiplications
- on supported devices and Triton based convolutions on GPU.
- It enables CUDA graphs by default on GPU.
- - "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs
- - To see the exact configs that each mode sets you can call `torch._inductor.list_mode_options()`
- options (dict): A dictionary of options to pass to the backend. Some notable ones to try out are
- - `epilogue_fusion` which fuses pointwise ops into templates. Requires `max_autotune` to also be set
- - `max_autotune` which will profile to pick the best matmul configuration
- - `fallback_random` which is useful when debugging accuracy issues
- - `shape_padding` which pads matrix shapes to better align loads on GPUs especially for tensor cores
- - `triton.cudagraphs` which will reduce the overhead of python with CUDA graphs
- - `trace.enabled` which is the most useful debugging flag to turn on
- - `trace.graph_diagram` which will show you a picture of your graph after fusion
- - `guard_filter_fn` that controls which dynamo guards are saved with compilations.
- This is an unsafe feature and there is no backward compatibility guarantee provided
- for dynamo guards as data types.
- For stable helper functions to use, see the documentations in `torch.compiler`, for example:
- - `torch.compiler.skip_guard_on_inbuilt_nn_modules_unsafe`
- - `torch.compiler.skip_guard_on_all_nn_modules_unsafe`
- - `torch.compiler.keep_tensor_guards_unsafe`
- - For inductor you can see the full list of configs that it supports by calling `torch._inductor.list_options()`
- disable (bool): Turn torch.compile() into a no-op for testing
- Example::
- @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
- def foo(x):
- return torch.sin(x) + torch.cos(x)
- """
- import sysconfig
- _C._log_api_usage_once("torch.compile")
- if sys.version_info >= (3, 15):
- raise RuntimeError("torch.compile is not supported on Python 3.15+")
- elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1 and sys.version_info < (
- 3,
- 13,
- 3,
- ):
- raise RuntimeError(
- "torch.compile is not supported on Python < 3.13.3 built with GIL disabled. "
- "Please use Python 3.13.3+."
- )
- # Decorator mode
- if model is None:
- def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
- if model is None:
- raise RuntimeError("Model can't be None")
- return compile( # pyrefly: ignore # no-matching-overload
- model,
- fullgraph=fullgraph,
- dynamic=dynamic,
- backend=backend,
- mode=mode,
- options=options,
- disable=disable,
- )
- return fn
- if mode is not None and options is not None:
- raise RuntimeError(
- "Either mode or options can be specified, but both can't be specified at the same time."
- )
- if mode is None and options is None:
- mode = "default"
- from torch._inductor.compiler_bisector import CompilerBisector
- if bisect_backend := CompilerBisector.get_backend():
- import torch._inductor.config as inductor_config
- # don't override the backend for use cases like vllm
- # which leverages their custom backend.
- if not (
- inductor_config.test_configs.bisect_keep_custom_backend_for_inductor
- and bisect_backend == "inductor"
- and not isinstance(backend, str)
- ):
- backend = bisect_backend
- guard_filter_fn = None
- use_aoti = False
- if options and isinstance(options, dict):
- guard_filter_fn = options.pop("guard_filter_fn", None)
- use_aoti = options.pop("use_aoti", False)
- if torch.compiler.is_exporting():
- from torch._higher_order_ops.utils import _in_hop_compile
- if not _in_hop_compile():
- warnings.warn(
- "torch.compile is ignored when called inside torch.export region",
- stacklevel=2,
- )
- # torch.compile is a no-op when inside torch.export region
- return model
- if backend == "inductor":
- if use_aoti:
- backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
- else:
- backend = _TorchCompileInductorWrapper(mode, options, dynamic)
- else:
- backend = _TorchCompileWrapper(backend, mode, options, dynamic)
- return torch._dynamo.optimize(
- backend=backend,
- nopython=fullgraph,
- dynamic=dynamic,
- disable=disable,
- guard_filter_fn=guard_filter_fn,
- )(model) # type: ignore[return-value]
- def _register_device_module(device_type, module):
- r"""Register an external runtime module of the specific :attr:`device_type`
- supported by torch.
- After the :attr:`module` is registered correctly, the user can refer
- the external runtime module as part of torch with attribute torch.xxx.
- """
- # Make sure the device_type represent a supported device type for torch.
- device_type = torch.device(device_type).type
- m = sys.modules[__name__]
- if hasattr(m, device_type):
- raise RuntimeError(
- f"The runtime module of '{device_type}' has already "
- f"been registered with '{getattr(m, device_type)}'"
- )
- setattr(m, device_type, module)
- torch_module_name = ".".join([__name__, device_type])
- sys.modules[torch_module_name] = module
- from torch import (
- export as export,
- func as func,
- library as library,
- return_types as return_types,
- )
- from torch._higher_order_ops import cond as cond, while_loop as while_loop
- from torch.func import vmap as vmap
- if not TYPE_CHECKING:
- # register python metas for distributed ops
- # Only import if distributed is available (USE_DISTRIBUTED=1)
- if hasattr(torch._C, "_c10d_init"):
- import torch.distributed._meta_registrations as coll_meta_registrations
- del coll_meta_registrations
- from torch import _meta_registrations
- # Enable CUDA Sanitizer
- if "TORCH_CUDA_SANITIZER" in os.environ:
- import torch.cuda._sanitizer as csan
- csan.enable_cuda_sanitizer()
- # Populate magic methods on SymInt and SymFloat
- import torch.fx.experimental.sym_node
- from torch import fx as fx
- # Register MPS specific decomps
- torch.backends.mps._init()
- from torch import compiler as compiler
- class _TritonLibrary:
- lib = torch.library.Library("triton", "DEF")
- ops_table: dict[tuple[str, str], _Callable] = {}
- @classmethod
- def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
- if (op_key, dispatch_key) not in cls.ops_table:
- cls.lib.define(full_schema)
- cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
- cls.ops_table[(op_key, dispatch_key)] = op_impl
- return cls.ops_table[(op_key, dispatch_key)]
- # Deprecated attributes
- _deprecated_attrs = {
- "has_mps": torch.backends.mps.is_built,
- "has_cuda": torch.backends.cuda.is_built,
- "has_cudnn": torch.backends.cudnn.is_available,
- "has_mkldnn": torch.backends.mkldnn.is_available,
- }
- if TYPE_CHECKING:
- # Import the following modules during type checking to enable code intelligence features,
- # such as auto-completion in tools like pylance, even when these modules are not explicitly
- # imported in user code.
- from torch import (
- _dynamo as _dynamo,
- _inductor as _inductor,
- _subclasses as _subclasses,
- onnx as onnx,
- )
- else:
- _lazy_modules = {
- "_dynamo",
- "_inductor",
- "_export",
- # ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
- "onnx",
- }
- def __getattr__(name):
- # Deprecated attrs
- replacement = _deprecated_attrs.get(name)
- if replacement is not None:
- import warnings
- warnings.warn(
- f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'",
- stacklevel=2,
- )
- return replacement()
- # Lazy modules
- if name in _lazy_modules:
- return importlib.import_module(f".{name}", __name__)
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
- @functools.cache
- def get_device_module(device: torch.device | str | None = None):
- """
- Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
- If no device is given, return the module for the current accelerator or CPU if none is present.
- """
- if isinstance(device, torch.device):
- device_module_name = device.type
- elif isinstance(device, str):
- device_module_name = torch.device(device).type
- elif device is None:
- # Using default accelerator type. If no accelerator is available, it automatically returns CPU device.
- device_module_name = torch._C._get_accelerator().type
- else:
- raise RuntimeError(
- f"Invalid value of device '{device}', expect torch.device, str, or None"
- )
- device_module = getattr(torch, device_module_name, None)
- if device_module is None:
- raise RuntimeError(
- f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'."
- )
- return device_module
- def _constrain_as_size(
- symbol,
- min: builtins.int | None = None,
- max: builtins.int | None = None,
- ):
- """
- This indicates that a given int is size-like, and can be used in any context where a size is expected.
- You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist()
- which then need to be used as tensor constructors. Providing these assertions to PyTorch can help resolve
- GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts.
- This function has unusual semantics in some circumstances in framework
- code, we will treat this int as >= 2 (when we do a size-oblivious guard).
- This makes it easier to use the unbacked int in size contexts,
- as we will often attempt to guard on a size being zero/one
- (e.g., when computing the contiguity of a tensor, or testing if
- broadcasting can occur), which will not work on unbacked SymInts.
- However, if we conservatively assume that the size is not zero/one, we will
- end up with a graph that will still work even if the size is zero/one.
- For more details, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit
- ```
- """
- torch.sym_constrain_range_for_size(symbol, min=min, max=max)
- from torch import _logging
- _logging._init_logs()
- def _import_device_backends():
- """
- Leverage the Python plugin mechanism to load out-of-the-tree device extensions.
- See this RFC: https://github.com/pytorch/pytorch/issues/122468
- """
- from importlib.metadata import entry_points
- group_name = "torch.backends"
- backend_extensions = entry_points(group=group_name)
- for backend_extension in backend_extensions:
- try:
- # Load the extension
- entrypoint = backend_extension.load()
- # Call the entrypoint
- entrypoint()
- except Exception as err:
- raise RuntimeError(
- f"Failed to load the backend extension: {backend_extension.name}. "
- f"You can disable extension auto-loading with TORCH_DEVICE_BACKEND_AUTOLOAD=0."
- ) from err
- def _is_device_backend_autoload_enabled() -> builtins.bool:
- """
- Whether autoloading out-of-the-tree device extensions is enabled.
- The switch depends on the value of the environment variable
- `TORCH_DEVICE_BACKEND_AUTOLOAD`.
- Returns:
- bool: Whether to enable autoloading the extensions. Enabled by default.
- Examples:
- >>> torch._is_device_backend_autoload_enabled()
- True
- """
- # enabled by default
- return os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1") == "1"
- def _as_tensor_fullprec(t):
- """
- Like torch.as_tensor, but when given Python data types it will keep
- them in full precision. Used for calling convention for Dynamo.
- Python scalars (float, int) are always created on CPU to avoid being
- affected by DeviceContext.
- """
- ty = type(t)
- if ty is builtins.float:
- return torch.as_tensor(t, dtype=torch.float64, device="cpu")
- elif ty is builtins.int:
- return torch.as_tensor(t, dtype=torch.int64, device="cpu")
- else:
- return torch.as_tensor(t)
- # `_import_device_backends` should be kept at the end to ensure
- # all the other functions in this module that may be accessed by
- # an autoloaded backend are defined
- if _is_device_backend_autoload_enabled():
- _import_device_backends()
|