| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
27742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375 |
- """
- This module contains classes and utilities for building variable trackers in Dynamo.
- Variable trackers are used to convert Python values into symbolic representations
- that can be traced and transformed during graph capture.
- The key classes are:
- - VariableBuilder: Handles source-tracked objects that need guards and proper
- reconstruction in the output graph. Used for inputs, module attributes, etc.
- - SourcelessBuilder: Handles ephemeral objects created during tracing that don't
- need source tracking or guards. Used for temporary lists, intermediate values, etc.
- Variable trackers enable Dynamo to track the flow of values through the program,
- maintain guards for dynamic properties, and reconstruct values in the output graph.
- The builders in this module handle converting Python values into appropriate
- VariableTracker instances based on their type and usage context.
- """
- import abc
- import collections
- import contextlib
- import copy
- import dataclasses
- import enum
- import functools
- import inspect
- import itertools
- import logging
- import math
- import operator
- import random
- import re
- import sys
- import time
- import types
- import weakref
- from collections.abc import Callable, MutableMapping
- from types import ModuleType
- from typing import Any, NamedTuple, NoReturn, Optional, overload, TYPE_CHECKING, Union
- import sympy
- import torch
- from torch import SymInt
- from torch._dispatch.python import enable_python_dispatcher
- from torch._dynamo.graph_bytecode_inputs import (
- get_external_object_by_index,
- register_user_object,
- )
- from torch._dynamo.utils import (
- get_metrics_context,
- is_int_specialization_case,
- is_torch_sym,
- set_feature_use,
- )
- from torch._guards import TracingContext
- from torch._higher_order_ops.flat_apply import flat_apply
- from torch._higher_order_ops.torchbind import call_torchbind
- from torch._library.opaque_object import (
- is_opaque_reference_type,
- is_opaque_type,
- is_opaque_value_type,
- should_hoist,
- )
- from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
- from torch._subclasses.fake_tensor import (
- FakeTensor,
- FakeTensorMode,
- is_fake,
- maybe_get_fake_mode,
- )
- from torch._subclasses.meta_utils import is_sparse_any, safe_grad
- from torch._utils_internal import justknobs_check
- from torch.fx.experimental._backward_state import BackwardState
- from torch.fx.experimental._dynamism import normalize_source_name
- from torch.fx.experimental.sym_node import _DynamicScalar, DynamicInt
- from torch.fx.experimental.symbolic_shapes import (
- _constrain_range_for_size,
- _nested_int_aware_sort,
- DimDynamic,
- RelaxedUnspecConstraint,
- StatefulSymbolicContext,
- SubclassSymbolicContext,
- SymbolicContext,
- SymIntSymbolicContext,
- TrackedFake,
- )
- from torch.fx.immutable_collections import immutable_dict, immutable_list
- from torch.nn.utils._expanded_weights import ExpandedWeight
- from torch.utils._ordered_set import OrderedSet
- from torch.utils._python_dispatch import (
- is_traceable_wrapper_subclass,
- is_traceable_wrapper_subclass_type,
- )
- from torch.utils._sympy.value_ranges import ValueRanges
- from torch.utils.weak import TensorWeakRef
- from .. import config, graph_break_hints, mutation_guard, replay_record, trace_rules
- from ..device_interface import get_registered_device_interfaces
- from ..exc import InternalTorchDynamoError, raise_observed_exception, unimplemented
- from ..guards import GuardBuilder, install_guard, make_dupe_guard
- from ..pgo import (
- auto_dynamic,
- auto_unset,
- FrameStateSizeEntry,
- InferStride,
- process_automatic_dynamic,
- )
- from ..side_effects import SideEffects
- from ..source import (
- AttrProxySource,
- AttrSource,
- CallMethodItemSource,
- ChainedSource,
- ConstDictKeySource,
- ConvertIntSource,
- DictGetItemSource,
- DictSubclassGetItemSource,
- DynamicScalarSource,
- FloatTensorSource,
- GetItemSource,
- GradSource,
- is_constant_source,
- is_from_closure_source,
- is_from_global_source,
- is_from_nonlocal_source,
- is_from_optimizer_source,
- is_from_unspecialized_nn_module_source,
- ListGetItemSource,
- LocalSource,
- NonSerializableSetGetItemSource,
- NumpyTensorSource,
- OptimizerSource,
- RandomValueSource,
- SkipGuardSource,
- Source,
- SubclassAttrListSource,
- TupleIteratorGetItemSource,
- UnspecializedBuiltinNNModuleSource,
- UnspecializedNNModuleSource,
- )
- from ..utils import (
- _extract_tensor_dict,
- build_checkpoint_variable,
- build_invoke_subgraph_variable,
- clone_input,
- common_constant_types,
- dict_keys,
- get_fake_value,
- get_items_from_dict,
- get_locals_to_steal,
- get_static_address_type,
- is_frozen_dataclass,
- is_function,
- is_function_or_wrapper,
- is_invoke_subgraph,
- is_lru_cache_wrapped_function,
- is_namedtuple,
- is_parameter_freezing,
- is_typing,
- is_utils_checkpoint,
- is_wrapper_or_member_descriptor,
- istype,
- namedtuple_fields,
- odict_values,
- proxy_args_kwargs,
- range_iterator,
- set_example_value,
- tensor_always_has_static_shape,
- tuple_iterator,
- tuple_iterator_getitem,
- tuple_iterator_len,
- unwrap_with_attr_name_if_wrapper,
- wrap_fake_exception,
- )
- from .base import (
- AttributeMutationNew,
- typestr,
- ValueMutationExisting,
- ValueMutationNew,
- VariableTracker,
- VariableTrackerMeta,
- )
- from .builtin import BuiltinVariable
- from .constant import ConstantVariable, EnumVariable
- from .ctx_manager import (
- AutocastModeVariable,
- DynamoConfigPatchVariable,
- ErrorOnGraphBreakVariable,
- NullContextVariable,
- PreserveVersionContextVariable,
- )
- from .dicts import (
- ConstDictVariable,
- DefaultDictVariable,
- DictKeySetVariable,
- FrozensetVariable,
- MappingProxyVariable,
- OrderedSetClassVariable,
- OrderedSetVariable,
- SetVariable,
- )
- from .distributed import (
- DeviceMeshVariable,
- PlacementClassVariable,
- PlacementVariable,
- ProcessGroupVariable,
- WorldMetaClassVariable,
- )
- from .functions import (
- BuiltinMethodVariable,
- CollectionsNamedTupleFunction,
- CollectiveFunctionRewriteVariable,
- CreateTMADescriptorExperimentalVariable,
- CreateTMADescriptorStableVariable,
- FunctoolsPartialVariable,
- FunctoolsWrapsVariable,
- SysFunctionVariable,
- TritonKernelVariable,
- TritonSetAllocatorSkipVariable,
- UserFunctionVariable,
- UserMethodVariable,
- WrapperUserFunctionVariable,
- )
- from .higher_order_ops import (
- LocalMapWrappedHigherOrderVariable,
- TorchHigherOrderOperatorVariable,
- )
- from .iter import ItertoolsVariable
- from .lazy import LazyConstantVariable, LazyVariableTracker
- from .lists import (
- BaseListVariable,
- ListIteratorVariable,
- ListVariable,
- NamedTupleVariable,
- RangeVariable,
- SizeVariable,
- SliceVariable,
- TupleIteratorVariable,
- TupleVariable,
- )
- from .misc import (
- AutogradEngineVariable,
- AutogradFunctionContextVariable,
- AutogradFunctionVariable,
- ComptimeVariable,
- ConstantLikeVariable,
- DebuggingVariable,
- DelayGraphBreakVariable,
- GetAttrVariable,
- GetSetDescriptorVariable,
- IgnoredFunctionVariable,
- LambdaVariable,
- LoggingLoggerVariable,
- MethodWrapperVariable,
- NumpyDTypeVariable,
- NumpyVariable,
- ObjectVariable,
- PythonModuleVariable,
- RandomClassVariable,
- RandomVariable,
- SavedTensorBox,
- TorchVersionVariable,
- TypingVariable,
- WeakRefVariable,
- )
- from .nn_module import (
- FSDPManagedNNModuleVariable,
- UnspecializedBuiltinNNModuleVariable,
- UnspecializedNNModuleVariable,
- )
- from .optimizer import OptimizerVariable
- from .script_object import OpaqueObjectClassVariable, TorchScriptObjectVariable
- from .sdpa import SDPAParamsVariable
- from .streams import EventVariable, StreamContextVariable, StreamVariable
- from .tensor import (
- NumpyNdarrayVariable,
- supported_const_comparison_op_values,
- SymNodeVariable,
- TensorSubclassVariable,
- TensorVariable,
- UnspecializedPythonVariable,
- )
- from .torch import (
- DispatchKeySetVariable,
- FuncTorchInterpreterVariable,
- TorchCtxManagerClassVariable,
- TorchInGraphFunctionVariable,
- )
- from .torch_function import (
- TensorWithTFOverrideVariable,
- torch_function_mode_stack_state_mgr,
- TorchFunctionModeVariable,
- )
- from .user_defined import (
- FrozenDataClassVariable,
- InspectVariable,
- IntWrapperVariable,
- KeyedJaggedTensorVariable,
- MutableMappingVariable,
- SourcelessGraphModuleVariable,
- UserDefinedClassVariable,
- UserDefinedDictVariable,
- UserDefinedEnumClassVariable,
- UserDefinedExceptionClassVariable,
- UserDefinedListVariable,
- UserDefinedObjectVariable,
- UserDefinedSetVariable,
- UserDefinedTupleVariable,
- )
- try:
- import numpy as np
- except ModuleNotFoundError:
- np: ModuleType = None # type: ignore[assignment]
- if TYPE_CHECKING:
- from torch._dynamo.codegen import PyCodegen
- from torch._dynamo.symbolic_convert import (
- InstructionTranslator,
- InstructionTranslatorBase,
- )
- log = logging.getLogger(__name__)
- static_inputs_log = torch._logging.getArtifactLogger(
- __name__, "cudagraph_static_inputs"
- )
- from typing import TypeVar
- # Placeholder for a VariableTracker to be used in proxy
- # creation so that we don't type erase
- VTTypeAlias = TypeVar("VTTypeAlias")
- T = TypeVar("T")
- DimList = list
- def safe_has_grad(t: object) -> bool:
- with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
- return hasattr(t, "grad")
class _missing(object):
    # Empty private marker class; no usage is visible in this chunk
    # (presumably a sentinel distinct from any real value, including None).
    ...
- @dataclasses.dataclass
- class GraphArg:
- source: Source | None
- # TODO: storing a SymInt here but not a FakeTensor is a pretty strange
- # thing to do. Probably should have example (which stores an int) and
- # fake_example
- _example: Any
- # When True, this indicates that this GraphArg is a Python quantity (e.g.,
- # a float or int) which we pass to the FX graph as a Tensor. This
- # controls how we codegen calls into the Dynamo graph: we will call
- # torch.as_tensor on the quantity before passing it in.
- #
- # Note that we typically do not pass dynamic integers as tensors, because
- # they will most frequently just be used for size computation. But this
- # is a policy decision that we can change our mind on; in particular, when
- # an int comes from a random number generator (e.g., random.randint), we
- # DO pass it as a tensor.
- #
- # It's also worth noting that our current tracing rules for
- # pass_arg_as_tensor as subtly broken: we just pun the variable as a
- # 0d scalar Tensor and pray that the semantics are the same. Which they
- # often are, but not necessarily. ezyang(May 2024) plans to fix this
- # soon.
- pass_arg_as_tensor: bool
- fake_tensor: torch._subclasses.fake_tensor.FakeTensor | None
- # UnspecializedPythonVariable often masquerades as a tensor.
- # We MUST NOT generate shape guard code
- # that actually tries to access tensor properties on these values.
- # is_tensor lets us tell if this graph arg actually is a tensor
- # or not.
- is_tensor: bool = True
- # Sometimes, the Tensor we pass to example is freshly allocated (smh).
- # Then we cannot only keep a weak reference to it. This lets you
- # stash a strong reference too.
- example_strong_ref: torch.Tensor | torch.SymInt | None = None
- def __setattr__(self, name: str, value: Any) -> None:
- # Use object.__setattr__ to bypass Dynamo's STORE_ATTR interception.
- # This is needed because when PYTORCH_TEST_WITH_DYNAMO=1, even internal
- # GraphArg creation can be traced, and with replay_side_effects=False,
- # normal STORE_ATTR bytecode only records mutations without applying them.
- object.__setattr__(self, name, value)
- @property
- def example(self) -> torch.Tensor | torch.SymInt | BackwardState | None:
- if isinstance(self._example, TensorWeakRef):
- r = self._example()
- assert r is not None
- return r
- else:
- return self._example
- def __post_init__(self) -> None:
- if isinstance(self._example, torch.Tensor):
- self._example = TensorWeakRef(self._example)
- assert is_fake(self.fake_tensor)
- def reconstruct(self, codegen: "PyCodegen") -> None:
- codegen(self.source)
- def erase(self) -> None:
- self._example = None
- self.example_strong_ref = None
- def __eq__(self, other: object) -> bool:
- if not isinstance(other, GraphArg):
- return False
- if self.source is None:
- return other.source is None
- else:
- if other.source is None:
- return False
- return self.source.name == other.source.name
class BackwardStateGraphArg(GraphArg):
    """Source-less, non-tensor GraphArg whose example is a new BackwardState."""

    def __init__(self) -> None:
        super().__init__(
            source=None,
            is_tensor=False,
            pass_arg_as_tensor=False,
            fake_tensor=None,
            _example=BackwardState(),
        )

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Requires the output graph to have a backward_state_var. Emits a call
        # to BackwardState(), duplicates the result, and stores one copy into
        # that variable.
        assert codegen.tx.output.backward_state_var

        def load_backward_state_cls():
            return codegen.load_import_from(BackwardState.__module__, "BackwardState")

        codegen.add_push_null(load_backward_state_cls)
        codegen.call_function(0, False)
        codegen.dup_top()
        codegen.store(codegen.tx.output.backward_state_var)
# id() of every public class exposed by the itertools module.
# Identities are stored instead of the objects themselves because some of
# these objects are not hashable and would raise during lookup.
ITERTOOLS_TYPE_IDS: frozenset[int] = frozenset(
    id(obj)
    for attr, obj in itertools.__dict__.items()
    if inspect.isclass(obj) and not attr.startswith("_")
)
# Filled in later by substitute_in_graph in torch/_dynamo/polyfills/itertools.py.
ITERTOOLS_POLYFILLED_TYPE_IDS: set[int] = set()
# Snapshot the original nn.Module.named_buffers / named_parameters at import
# time. This guards against marking the iterated tensors as static when a user
# has overridden these functions.
og_module_named_buffers_fn_ptr = torch.nn.Module.named_buffers
og_module_named_parameters_fn_ptr = torch.nn.Module.named_parameters
- class VariableBuilder:
- """Wrap a python value in a VariableTracker() instance"""
def __init__(
    self,
    tx: "InstructionTranslator",
    source: Source,
    allow_lazy_constant: bool = True,
) -> None:
    """Create a builder for values reachable from ``source``.

    Args:
        tx: the active instruction translator.
        source: origin of the value to wrap. Must not be None; use
            SourcelessBuilder for ephemeral, locally created objects.
        allow_lazy_constant: whether a LazyConstantVariable may be returned
            for int/float/bool/str. LazyCache.realize() passes False so that
            a LazyVariableTracker never ends up wrapping a
            LazyConstantVariable.

    Asserts that a TracingContext is currently active.
    """
    assert source is not None, (
        "Consider SourcelessBuilder for ephemeral objects, usually objects created locally."
    )
    assert TracingContext.try_get() is not None, "Expected active TracingContext"
    super().__init__()
    self.tx, self.source = tx, source
    # Keep the source's name handy alongside the source itself.
    self.name = source.name
    self.allow_lazy_constant = allow_lazy_constant
def __call__(self, value: object) -> VariableTracker:
    """Wrap ``value`` via ``_call_impl``, charging the elapsed wall time (ns).

    The time is added to ``tx.output.bytecode_tracing_timings.variable_builder_call_ns``
    even when ``_call_impl`` raises.
    """
    started_ns = time.time_ns()
    try:
        wrapped = self._call_impl(value)
    finally:
        timings = self.tx.output.bytecode_tracing_timings
        timings.variable_builder_call_ns += time.time_ns() - started_ns
    return wrapped
def _call_impl(self, value: object) -> VariableTracker:
    """Return a VariableTracker for ``value``, reusing existing trackers.

    Lookup order:
      1. If ``value`` is already tracked in ``tx.output.side_effects``, return
         that tracker. First install any dupe guard ``make_dupe_guard``
         produces between this source and the tracker's source, and, for
         nn.Modules tracked as UnspecializedNNModuleVariable, point the
         tracker's nn_module_stack at this source.
      2. Otherwise, if ``tx.output.variable_tracker_cache`` has a truthy entry
         for this source, return it. The exception is when lazy constants are
         disallowed and the cached entry is a LazyVariableTracker; that case
         falls through and rebuilds.
      3. Otherwise build a fresh tracker with ``self._wrap``, fill in its
         source if missing, and register it with side_effects when it is a
         liftable type or a deduplicable symbolic variable. Then cache it
         under this source, unless the source name contains "JVP_NESTING".
    """
    if value in self.tx.output.side_effects:
        side_effect_result = self.tx.output.side_effects[value]
        dup_guard = make_dupe_guard(self.source, side_effect_result.source)
        if dup_guard:
            self.install_guards(dup_guard)
        if isinstance(value, torch.nn.Module) and isinstance(
            side_effect_result, UnspecializedNNModuleVariable
        ):
            # This means that two nn module instances with different sources
            # have the same id. NN modules are somewhat special objects,
            # because we have to track their nn_module_stack for ease of
            # use. But if we don't do anything, we will just return the
            # older variable tracker with the older nn_module_stack. So,
            # lets return the old variable tracker but update its
            # nn_module_stack
            side_effect_result.set_nn_module_stack_source(self.source)
        return side_effect_result
    cached_vt = self.tx.output.variable_tracker_cache.get(self.source)
    if cached_vt:
        # If allow_lazy_constant=False but the cached VT is a lazy variable,
        # we need to rebuild to get a non-lazy version. This happens when
        # LazyConstantVariable.realize() calls VariableBuilder.
        if self.allow_lazy_constant or not isinstance(
            cached_vt, LazyVariableTracker
        ):
            return cached_vt
    vt = self._wrap(value)
    # Trackers built without a source inherit this builder's source.
    if vt.source is None:
        vt.source = self.source

    def _is_deduplicable_sym_variable(value: Any, vt: VariableTracker) -> bool:
        # Constants like 0, 1, 2, etc. can be unspecialized as SymNodeVariables sometimes, but we
        # should NOT track them. If we use a single SymNodeVariable instance to track them
        # across multiple uses, then guards created for one usage will incorrectly apply to
        # all other usages of that constant, leading to unnecessary recompilations.
        return (
            is_torch_sym(value) or isinstance(value, _DynamicScalar)
        ) and isinstance(vt, SymNodeVariable)

    # Register the new tracker as an existing object in side_effects, so a
    # later lookup of the same value hits step 1. Skip values already tracked
    # and wrapper/member descriptors.
    if (
        (
            self._can_lift_attrs_to_inputs(vt)
            or _is_deduplicable_sym_variable(value, vt)
        )
        and value not in self.tx.output.side_effects
        and not is_wrapper_or_member_descriptor(value)
    ):
        vt = self.tx.output.side_effects.track_object_existing(value, vt)
    # Skip caching for JVP_NESTING source because
    # JvpIncrementNestingCtxManagerVariable hides global JVP mutation from
    # Dynamo, resulting in stale value. We attempted a fix in
    # https://github.com/pytorch/pytorch/pull/174329 but it exposed other
    # issues. This only affects cache hit rate, NOT correctness.
    if "JVP_NESTING" not in self.source.name:
        self.tx.output.variable_tracker_cache[self.source] = vt
    return vt
def _can_lift_attrs_to_inputs(self, vt: VariableTracker) -> bool:
    """Return True when ``vt``'s exact type is one of the tracker classes
    that ``_call_impl`` may register with ``track_object_existing``.

    Only the exact type is compared, so subclasses of these trackers do not
    qualify.
    """
    liftable_tracker_types = (
        TensorVariable,
        TensorWithTFOverrideVariable,
        UserDefinedObjectVariable,
        NumpyNdarrayVariable,
        TorchScriptObjectVariable,
    )
    tracker_type = type(vt)
    return tracker_type in liftable_tracker_types
def get_source(self) -> Source:
    """Return the builder's current source.

    Some ``_wrap`` branches rebind ``self.source``, so this may differ from
    the source passed to the constructor.
    """
    current_source = self.source
    return current_source
- def install_guards(self, *guards: Callable[..., Any]) -> dict[str, Any] | None:
- source = self.get_source()
- try:
- tmp = [source.make_guard(guard) for guard in guards]
- except NotImplementedError:
- return None
- install_guard(*tmp, skip=1)
- return {}
@classmethod
def _type_dispatch(cls) -> dict[object, Callable[..., Any]]:
    """Return the exact-type -> wrap-method table for the current config.

    ``config.trace_numpy`` is read here and passed explicitly, so the
    memoized ``_type_dispatch_impl`` keeps a separate table per setting.
    """
    trace_numpy_enabled = config.trace_numpy
    return cls._type_dispatch_impl(trace_numpy_enabled)
@classmethod
@functools.cache
def _type_dispatch_impl(cls, trace_numpy: bool) -> dict[object, Callable[..., Any]]:
    """Build the table mapping an exact ``type(value)`` to its wrap method.

    The result is memoized per ``(cls, trace_numpy)``. Handlers are
    attributes looked up on ``cls``, and ``_wrap`` calls them as
    ``fn(self, value)``. Registering the same type twice trips an assertion.
    The ``np.ndarray`` entry is added only when ``trace_numpy`` is set and
    numpy is available.
    """
    # NB: Careful not to close over self to avoid ref cycle from lru_cache
    tensor_types = (
        torch.Tensor,
        torch.nn.Parameter,
        torch._subclasses.FakeTensor,
        torch._subclasses.functional_tensor.FunctionalTensor,
    )
    listlike_types = (tuple, list, odict_values, collections.deque, torch.Size)
    entries = [
        (tensor_types, cls.wrap_tensor),
        (listlike_types, cls.wrap_listlike),
        (tuple_iterator, cls.wrap_tuple_iterator),
        (range_iterator, cls.wrap_range_iterator),
        ((slice, range), cls.wrap_slice_range),
        (tuple(common_constant_types), cls.wrap_literal),
        (re.Pattern, cls.wrap_regex_pattern),
        (weakref.ReferenceType, cls.wrap_weakref),
        (torch.utils.hooks.RemovableHandle, cls.wrap_removable_handle),
        (torch.jit.ScriptFunction, cls.wrap_jit_function),
        (types.MappingProxyType, cls.wrap_mapping_proxy),
    ]

    if trace_numpy and np:
        # pyrefly: ignore [bad-argument-type]
        entries.append((np.ndarray, cls.wrap_numpy_ndarray))

    # pyrefly: ignore [implicit-any]
    table = {}
    for type_group, handler in entries:
        group = type_group if isinstance(type_group, tuple) else (type_group,)
        for exact_type in group:
            assert exact_type not in table
            table[exact_type] = handler
    return table
def wrap_regex_pattern(self, value: re.Pattern[Any]) -> ConstantLikeVariable:
    """Wrap a compiled regex as a constant-like tracker, guarded by identity.

    The tracker is created without a source. When this is reached through
    ``_call_impl``, that method assigns ``self.source`` to it afterwards.
    """
    # TODO(jansel): something like a REPR_MATCH might be more robust here
    self.install_guards(GuardBuilder.ID_MATCH)
    pattern_vt = ConstantLikeVariable(value)
    return pattern_vt
def wrap_weakref(self, value: weakref.ReferenceType[Any]) -> WeakRefVariable:
    """Wrap a weak reference, guarding only on its type."""
    self.install_guards(GuardBuilder.TYPE_MATCH)
    ref_source = self.source
    return WeakRefVariable.build(self.tx, value, source=ref_source)
def wrap_removable_handle(
    self, value: torch.utils.hooks.RemovableHandle
) -> NoReturn:
    """Reject a RemovableHandle that was not created in the current frame.

    Hooks must be registered and removed within the same frame, so a
    handle reaching the builder came from elsewhere and causes a graph
    break via ``unimplemented``.
    """
    # Related test - PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_hooks
    explanation = (
        "Dynamo attempted to build a representation of a torch.utils.hooks.RemovableHandle, "
        "which is not supported. This happens because the RemovableHandle was created in another frame."
    )
    unimplemented(
        gb_type="Attempted to represent unregistered RemovableHandle",
        context="",
        explanation=explanation,
        hints=[],
    )
def wrap_jit_function(self, value: Any) -> WrapperUserFunctionVariable:
    """Wrap a ``torch.jit.ScriptFunction`` as a WrapperUserFunctionVariable
    keyed on its ``"_torchdynamo_inline"`` attribute; only the type is guarded.
    """
    self.install_guards(GuardBuilder.TYPE_MATCH)
    inline_attr = "_torchdynamo_inline"
    return WrapperUserFunctionVariable(value, inline_attr, source=self.source)
def wrap_mapping_proxy(self, value: Any) -> VariableTracker:
    """Wrap a ``types.MappingProxyType`` whose keys are all literals.

    Guards on the proxy's type and on all of its keys. Graph-breaks via
    ``unimplemented`` if any key is not a literal. Each value becomes a lazy
    tracker sourced as ``GetItemSource(source, key)``. The resulting
    MappingProxyVariable is registered with ``track_mutable``.
    """
    self.install_guards(GuardBuilder.TYPE_MATCH)
    # This might be suboptimal compared to dict guards. But mappingproxy is
    # not very common, so its ok to guard on all keys.
    self.install_guards(GuardBuilder.MAPPING_KEYS_CHECK)

    if not all(ConstantVariable.is_literal(k) for k in value):
        unimplemented(
            gb_type="non-const keys in mappingproxy",
            context=f"non-const keys: {[k for k in value.keys() if not ConstantVariable.is_literal(k)]}",  # noqa: SIM118
            explanation="Dynamo expects mappingproxy keys to be constants.",
            hints=[
                "Ensure your mappingproxy keys are constants (e.g. int, float, strings)",
            ],
        )

    base_source = self.get_source()
    items = {}
    for raw_key, raw_value in value.items():
        key_vt = ConstantVariable.create(raw_key)
        value_vt = LazyVariableTracker.create(
            raw_value, GetItemSource(base_source, raw_key)
        )
        items[key_vt] = value_vt

    # The backing dict tracker is sourceless; only the proxy carries a source.
    # pyrefly: ignore[bad-argument-type]
    backing_dict_vt = ConstDictVariable(items, source=None)
    proxy_vt = MappingProxyVariable(backing_dict_vt, source=self.source)
    return self.tx.output.side_effects.track_mutable(value, proxy_vt)
@classmethod
@functools.cache
def _id_dispatch(
    cls,
) -> dict[int, Callable[["VariableBuilder", Any], VariableTracker]]:
    """Build (once per class) a table mapping ``id(obj)`` to a wrap function.

    Only a few specific objects are covered: ``comptime``,
    ``dataclasses.fields`` and ``torch.__version__``. ``_wrap`` calls each
    handler as ``fn(builder, value)`` when ``id(value)`` matches. Registering
    the same object twice trips an assertion.
    """
    from ..comptime import comptime

    entries = [
        (comptime, lambda self, value: ComptimeVariable()),
        (
            dataclasses.fields,
            lambda self, value: LambdaVariable(
                _dataclasses_fields_lambda,
                source=self.source,
                # install_guards returns None when the source cannot build
                # the guard (make_guard raised NotImplementedError).
                # Splatting None with ** would raise TypeError, so fall back
                # to no extra keyword arguments.
                **(self.install_guards(GuardBuilder.CLOSURE_MATCH) or {}),
            ),
        ),
        (torch.__version__, lambda self, value: TorchVersionVariable()),
    ]

    # pyrefly: ignore [implicit-any]
    result = {}
    for ts, fn in entries:
        for t in ts if isinstance(ts, (tuple, list)) else (ts,):
            assert t not in result
            result[id(t)] = fn
    return result
- def _wrap(self, value: Any) -> VariableTracker:
- # import here to avoid circular dependencies
- from torch.utils._triton import (
- has_triton,
- has_triton_experimental_host_tma,
- has_triton_tensor_descriptor_host_tma,
- )
- from ..decorators import (
- DynamoConfigPatchProxy,
- ErrorOnGraphBreakDecoratorContextManager,
- )
- if has_triton():
- from triton.runtime.autotuner import Autotuner
- from triton.runtime.jit import JITFunction
- else:
- class JITFunction:
- pass
- class Autotuner:
- pass
- # default implementations, in case we don't have triton (or the wrong triton version)
- def create_1d_tma_descriptor() -> None:
- pass
- def create_2d_tma_descriptor() -> None:
- pass
- class TensorDescriptor:
- @staticmethod
- def from_tensor() -> None:
- pass
- def set_allocator() -> None:
- pass
- if has_triton_experimental_host_tma():
- from triton.tools.experimental_descriptor import ( # noqa: F811
- create_1d_tma_descriptor,
- create_2d_tma_descriptor,
- )
- if has_triton_tensor_descriptor_host_tma():
- from triton.tools.tensor_descriptor import TensorDescriptor # noqa: F811
- if has_triton():
- import triton as triton_mod
- if hasattr(triton_mod, "set_allocator"):
- set_allocator = triton_mod.set_allocator # noqa: F811
- # Handle exact type() match
- type_dispatch = self._type_dispatch().get(type(value))
- if type_dispatch is not None:
- return type_dispatch(self, value)
- # Handle exact id() match
- id_dispatch = self._id_dispatch().get(id(value))
- if id_dispatch is not None:
- return id_dispatch(self, value)
- # Everything else (NB: order matters!)
- if (
- isinstance(value, torch.Tensor)
- and type(value)
- not in (
- # These torch-native subclasses have overly restrictive
- # `__torch_function__` which prevents Dynamo from reading their
- # tensor attributes like `is_nested` or calling methods like
- # `_is_view`.
- torch.nn.parameter.UninitializedBuffer,
- torch.nn.parameter.UninitializedParameter,
- ExpandedWeight,
- )
- and type(value) not in config.nontraceable_tensor_subclasses
- ):
- if (
- type(value).__torch_dispatch__ is torch.Tensor.__torch_dispatch__
- or is_traceable_wrapper_subclass(value)
- ):
- return self.wrap_tensor(value)
- if is_namedtuple(value):
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- output: list[VariableTracker] = [
- LazyVariableTracker.create(
- getattr(value, name),
- source=AttrSource(self.source, name),
- )
- for name in namedtuple_fields(type(value))
- ]
- tuple_vt = TupleVariable(
- output,
- source=self.source,
- mutation_type=ValueMutationExisting(),
- )
- result = NamedTupleVariable(
- output,
- tuple_cls=type(value),
- source=self.source,
- tuple_vt=tuple_vt,
- )
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- all_const = all(ConstantVariable.is_literal(k) for k in value)
- # For all_const, we don't have to guard on anything yet. We guard on
- # keys lazily by adding a dict_getitem entry for each accessed key.
- # For cases where we need to guard on all keys, we lazily put guards
- # during the dict call_method (check dicts.py)
- if not all_const:
- # Guard on the key order
- # This is not ideal, i.e., there is no need to guard on the key
- # order. But we guard on the key order because of the complexity
- #
- # 1) For non-constant objects, we can't save the key in the
- # guard context because it can be memory heavy. We can add
- # weakrefs but this complicates the accesses.
- #
- # 2) For non-constant objects, we also have to guard on the keys
- # (like TENSOR_MATCH on tensor). We might also have guards on
- # the attributes of the keys (like tensor.grad). To make this
- # work in tree structure is complicated.
- #
- # So, instead we guard on the key order. While guarding on key
- # order, we just save the indices and use it to access keys and
- # values. Indices are cheap to save.
- self.tx.output.guard_on_key_order.add(self.source)
- # We need all the keys to be hashable. We do this within the
- # _HashableTracker class in dicts.py
- def build_key_value(
- i: Any, k: Any, v: Any
- ) -> tuple[VariableTracker, VariableTracker]:
- base = self.get_source()
- if all_const:
- key = ConstantVariable.create(k)
- source_key = k
- else:
- source_key = ConstDictKeySource(base, i)
- key = LazyVariableTracker.create(k, source_key)
- source_value = DictGetItemSource(base, source_key)
- res_value = LazyVariableTracker.create(v, source_value)
- return key, res_value
- # Ensure that we call dict.keys and not value.keys (which can call
- # overridden keys method). In the C++ guards, we relied on
- # PyDict_Next to traverse the dictionary, which uses the internal
- # data structure and does not call the overridden keys method.
- result = dict(
- build_key_value(i, k, v)
- for i, (k, v) in enumerate(get_items_from_dict(value))
- )
- if istype(value, collections.defaultdict):
- factory_source = AttrSource(self.source, "default_factory")
- result = DefaultDictVariable(
- result, # type: ignore[arg-type]
- type(value),
- default_factory=VariableBuilder(self.tx, factory_source)(
- value.default_factory
- ),
- source=self.source,
- )
- else:
- result = ConstDictVariable(
- result, # type: ignore[arg-type]
- user_cls=type(value),
- source=self.source,
- )
- return self.tx.output.side_effects.track_mutable(value, result)
- elif isinstance(value, torch.nn.Module):
- return self.wrap_module(value)
- elif ConstantVariable.is_literal(value): # non-atomic literals
- return self.wrap_literal(value)
- elif isinstance(value, torch.overrides.TorchFunctionMode):
- var = TorchFunctionModeVariable(value, source=self.source)
- self.tx.output.side_effects.track_object_existing(value, var)
- return var
- elif istype(value, (set, OrderedSet)):
- if any(isinstance(x, torch.Tensor) for x in value):
- unimplemented(
- gb_type="Attempted to wrap a set with tensors",
- context="Python set containing torch.Tensor elements",
- explanation=(
- "Dynamo cannot trace sets of tensors. To get a stable ordering, "
- "Dynamo needs to convert the set into a list and the order might not be "
- "stable if the set contains tensors."
- ),
- hints=[
- "Use a dictionary where the keys are tensors.",
- *graph_break_hints.SUPPORTABLE,
- ],
- )
- self.install_guards(GuardBuilder.TYPE_MATCH)
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- set_var_cls = SetVariable
- if istype(value, OrderedSet):
- # Guard on the internal dict of OrderedSet
- internal_dict_source = AttrSource(self.source, "_dict")
- install_guard(
- internal_dict_source.make_guard(GuardBuilder.DICT_KEYS_MATCH)
- )
- self.tx.output.guard_on_key_order.add(internal_dict_source)
- set_var_cls = OrderedSetVariable
- # The list gives a ordering for the set items. The ordering is based
- # on the Python hash and it is not related to object ordering inside
- # the set object. The order being incorrect at runtime will lead to
- # a recompilation.
- L = list(value)
- items = [
- LazyVariableTracker.create(
- v, source=NonSerializableSetGetItemSource(self.source, i)
- )
- for i, v in enumerate(L)
- ]
- result = set_var_cls(items, source=self.source)
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif istype(value, frozenset) and all(
- (
- # For DBR quantization, we could get a frozenset of torch funcs.
- (type(x) is types.BuiltinMethodType and x.__module__ == "torch")
- or
- # Another commonly used frozenset of types.
- x in torch.utils._pytree.BUILTIN_TYPES
- or
- # For activation checkpointing, we could get a frozenset of torch ops.
- isinstance(x, (OpOverload, OpOverloadPacket))
- )
- for x in value
- ):
- # For the limited cases of frozenset here, we know the items won't
- # change across runs, so we can safely create sourceless VTs for
- # them and guard on the frozenset contents via EQUALS_MATCH.
- # TODO support source for sets and remove the special logics here.
- items = [SourcelessBuilder.create(self.tx, v) for v in value]
- self.install_guards(GuardBuilder.EQUALS_MATCH)
- return FrozensetVariable(items, source=self.source)
- elif isinstance(
- value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType)
- ):
- self.install_guards(GuardBuilder.ID_MATCH)
- return EnumVariable(value=value, source=self.source)
- elif DebuggingVariable.is_reorderable_logging_function(value):
- # Put this above builtin_callable so that print() can be handled
- # along with other builtin debugging functions
- self.install_guards(GuardBuilder.BUILTIN_MATCH)
- return DebuggingVariable(value, source=self.source)
- elif callable(value) and any(
- value is fn for fn in torch._dynamo.config.ignore_logging_functions
- ):
- # Treat ignored functions as full no-ops
- self.install_guards(GuardBuilder.ID_MATCH)
- return IgnoredFunctionVariable(value, source=self.source)
- elif isinstance(value, logging.Logger):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return LoggingLoggerVariable(value, source=self.source)
- elif is_utils_checkpoint(value):
- return build_checkpoint_variable(source=self.source)
- elif is_invoke_subgraph(value):
- return build_invoke_subgraph_variable(source=self.source)
- elif LocalMapWrappedHigherOrderVariable.should_wrap_in_hop(value):
- return LocalMapWrappedHigherOrderVariable.build(source=self.source)
- elif isinstance(value, functools.partial):
- func_src = AttrSource(self.get_source(), "func")
- func_obj = VariableBuilder(self.tx, func_src)(value.func)
- args = []
- args_source = AttrSource(self.get_source(), "args")
- for i, arg in enumerate(value.args):
- args.append(
- VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
- )
- keywords = {}
- keywords_source = AttrSource(self.get_source(), "keywords")
- for k, v in value.keywords.items():
- if not ConstantVariable.is_literal(k):
- unimplemented(
- gb_type="functools.partial() with non-literal keyword",
- context=f"non-literal keyword: {k}",
- explanation="functools.partial() expects literal/string keywords",
- hints=[*graph_break_hints.USER_ERROR],
- )
- keywords[k] = VariableBuilder(
- self.tx, DictGetItemSource(keywords_source, k)
- )(v)
- install_guard(
- self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
- keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH),
- args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
- )
- # Preserve cache_hash for SAC context_fn caching
- original_cache_hash = getattr(value, "cache_hash", None)
- return FunctoolsPartialVariable(
- func_obj, args, keywords, original_cache_hash=original_cache_hash
- )
- elif is_typing(value):
- # typing.List, typing.Mapping, etc.
- self.install_guards(GuardBuilder.ID_MATCH)
- return TypingVariable(
- value,
- source=self.source,
- )
- elif np is not None and isinstance(value, np.generic):
- # numpy array scalars: convert to 0D arrays
- return self.wrap_numpy_ndarray(np.asarray(value))
- elif trace_rules.is_numpy(value):
- assert np
- if istype(value, types.MethodType):
- # Dont guard on cython functions as they dont change ids
- if inspect.isfunction(value.__func__):
- install_guard(
- AttrSource(self.source, "__func__").make_guard(
- GuardBuilder.CLOSURE_MATCH
- )
- )
- elif inspect.isclass(value):
- self.install_guards(GuardBuilder.CLASS_MATCH)
- elif inspect.isfunction(value):
- self.install_guards(GuardBuilder.CLOSURE_MATCH)
- elif callable(value):
- self.install_guards(GuardBuilder.ID_MATCH)
- else:
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return NumpyVariable(value, source=self.source)
- elif trace_rules.is_numpy_dtype(value):
- self.install_guards(GuardBuilder.ID_MATCH)
- return NumpyDTypeVariable(value, source=self.source)
- elif trace_rules.is_numpy_type_info(value):
- if isinstance(value, np.iinfo):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- dt_source = AttrSource(self.source, "dtype")
- install_guard(dt_source.make_guard(GuardBuilder.ID_MATCH))
- else:
- self.install_guards(GuardBuilder.ID_MATCH)
- return ConstantLikeVariable(value, source=self.source)
- # NB: These can't be put in type_dispatch, they have to run later
- elif CollectiveFunctionRewriteVariable.can_rewrite(value):
- self.install_guards(GuardBuilder.CLOSURE_MATCH)
- return CollectiveFunctionRewriteVariable.create(
- self.tx,
- value,
- source=self.source,
- )
- elif istype(value, torch.autograd.function.FunctionMeta):
- self.install_guards(GuardBuilder.CLASS_MATCH)
- return AutogradFunctionVariable(
- value,
- source=self.source,
- )
- elif isinstance(value, torch.autograd.function.FunctionCtx):
- actual_saved_tensors = None
- try:
- # type: ignore[attr-defined]
- actual_saved_tensors = value.saved_tensors
- except RuntimeError:
- pass
- saved_tensors = []
- guards = [self.source.make_guard(GuardBuilder.TYPE_MATCH)]
- if isinstance(actual_saved_tensors, tuple):
- saved_tensors_source = AttrSource(self.source, "saved_tensors")
- guards.append(
- saved_tensors_source.make_guard(GuardBuilder.SEQUENCE_LENGTH)
- )
- for i, v in enumerate(actual_saved_tensors):
- saved_tensors.append(
- VariableBuilder(
- self.tx, GetItemSource(saved_tensors_source, i)
- )(v)
- )
- install_guard(*guards)
- return self.tx.output.side_effects.track_object_existing(
- value,
- AutogradFunctionContextVariable(
- value,
- source=self.source,
- saved_tensors=SavedTensorBox(saved_tensors),
- ),
- )
- elif (
- isinstance(value, types.MethodType)
- and istype(
- getattr(value, "__self__", None), torch.autograd.function.FunctionMeta
- )
- and getattr(value, "__name__", "") == "apply"
- and value == getattr(value.__self__, "apply", None)
- ):
- # handle aliased autograd function `apply` calls
- install_guard(
- AttrSource(self.get_source(), "__func__").make_guard(
- GuardBuilder.CLOSURE_MATCH
- )
- )
- return GetAttrVariable(
- AutogradFunctionVariable(
- value.__self__,
- source=AttrSource(self.source, member="__self__"),
- ),
- "apply",
- )
- elif isinstance(value, torch._C._ImperativeEngine):
- self.install_guards(GuardBuilder.ID_MATCH)
- return AutogradEngineVariable(value, source=self.source)
- elif (
- value
- is torch._dynamo.external_utils.FakeCompiledAutogradEngine._exec_final_callbacks_stub
- ):
- self.install_guards(GuardBuilder.CLOSURE_MATCH)
- return LambdaVariable(
- lambda: UserFunctionVariable(
- torch._dynamo.external_utils.FakeCompiledAutogradEngine.exec_final_callbacks,
- ).call_function(
- self.tx,
- (self.tx.output.side_effects.get_ca_final_callbacks_var(),),
- {},
- )
- )
- elif isinstance(value, DynamoConfigPatchProxy):
- return DynamoConfigPatchVariable(value.changes)
- elif isinstance(value, ErrorOnGraphBreakDecoratorContextManager):
- return ErrorOnGraphBreakVariable(value.error_on_graph_break)
- elif callable(value) and trace_rules.lookup_callable(value) is not None:
- if trace_rules.is_callable_allowed(value):
- self.tx.output.has_user_defined_allowed_in_graph = True
- # type: ignore[attr-defined]
- return trace_rules.lookup_callable(value).create_with_source(
- value, source=self.source
- )
- elif np and isinstance(value, np.number):
- return self.wrap_unspecialized_primitive(value)
- elif isinstance(value, HigherOrderOperator):
- if value is torch._higher_order_ops.invoke_subgraph:
- unimplemented(
- gb_type="Attempted to wrap torch._higher_order_ops.invoke_subgraph",
- context="",
- explanation="Directly using invoke_subgraph is not supported. Use nested_compile_region",
- hints=[],
- )
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return TorchHigherOrderOperatorVariable.make(value, source=self.source)
- elif isinstance(value, torch.cuda.StreamContext):
- self.install_guards(GuardBuilder.ID_MATCH)
- stream_source = AttrSource(self.source, "stream")
- stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
- # type: ignore[arg-type]
- return StreamContextVariable.create(self.tx, stream_var)
- elif isinstance(value, torch.Stream):
- # This refers to the device-agnostic torch.Stream
- self.install_guards(GuardBuilder.TYPE_MATCH)
- index = register_user_object(value, self.source)
- stream_proxy = self.tx.output.create_proxy(
- "call_function", get_external_object_by_index, (index,), {}
- )
- set_example_value(stream_proxy.node, value)
- var = StreamVariable(
- stream_proxy, value, source=self.source, user_object_index=index
- )
- return self.tx.output.side_effects.track_object_existing(value, var)
- elif isinstance(value, (torch._C._SDPAParams)):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return SDPAParamsVariable.create(self.tx, value, self.source)
- elif isinstance(value, torch._functorch.pyfunctorch.FuncTorchInterpreter):
- self.install_guards(GuardBuilder.ID_MATCH)
- return FuncTorchInterpreterVariable(value)
- elif isinstance(value, torch.Event):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- index = register_user_object(value, self.source)
- event_proxy = self.tx.output.create_proxy(
- "call_function",
- get_external_object_by_index,
- (index,),
- {},
- )
- set_example_value(event_proxy.node, value)
- return EventVariable(
- event_proxy,
- value,
- index,
- source=self.source,
- )
- elif (
- istype(value, contextlib.nullcontext)
- and inspect.getattr_static(value, "enter_result", None) is None
- ):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return NullContextVariable(source=self.source)
- elif KeyedJaggedTensorVariable.is_matching_object(value):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- result = KeyedJaggedTensorVariable(value, source=self.source)
- # TODO: this doing it manually is bad
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif isinstance(value, torch.optim.Optimizer):
- self.install_guards(GuardBuilder.ID_MATCH)
- self.source = OptimizerSource(self.source)
- return OptimizerVariable(value, source=self.source)
- elif isinstance(value, torch.DispatchKeySet):
- self.install_guards(GuardBuilder.DISPATCH_KEY_SET_MATCH)
- return DispatchKeySetVariable(value)
- elif WorldMetaClassVariable.is_group_member_type(value):
- return WorldMetaClassVariable(value, source=self.source)
- elif ProcessGroupVariable.is_process_group(value):
- self.install_guards(GuardBuilder.ID_MATCH)
- return ProcessGroupVariable(value, source=self.source)
- elif DeviceMeshVariable.is_device_mesh(value):
- # TODO: see if we need to add custom guard instead of a simple ID_MATCH
- self.install_guards(GuardBuilder.EQUALS_MATCH)
- return DeviceMeshVariable(value, source=self.source)
- elif PlacementClassVariable.is_placement_type(value):
- # TODO: see if we need to add custom guard instead of a simple ID_MATCH
- self.install_guards(GuardBuilder.ID_MATCH)
- return PlacementClassVariable(value, source=self.source)
- elif PlacementVariable.is_placement(value):
- # TODO: see if we need to add custom guard instead of a simple ID_MATCH
- self.install_guards(GuardBuilder.EQUALS_MATCH)
- return PlacementVariable(
- value,
- source=self.source,
- )
- elif value is OrderedSet:
- self.install_guards(GuardBuilder.ID_MATCH)
- return OrderedSetClassVariable()
- elif (
- id(value) in ITERTOOLS_TYPE_IDS
- and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS
- ):
- self.install_guards(GuardBuilder.CLASS_MATCH)
- return ItertoolsVariable(value, source=self.source)
- elif isinstance(value, _DynamicScalar):
- is_int = isinstance(value, DynamicInt)
- source = DynamicScalarSource(self.source, is_int)
- if id(value) in self.tx.output.root_tracer.dynamic_scalar_nodes:
- # If we've already seen this dynamic scalar, reuse the existing
- # SymInt/SymFloat node.
- node = self.tx.output.root_tracer.dynamic_scalar_nodes[id(value)]
- else:
- sym = self.tx.output.shape_env.create_unspecified_symbol(
- value.real, # type: ignore[attr-defined]
- source=source,
- dynamic_dim=DimDynamic.DYNAMIC,
- )
- node = self.tx.output.shape_env.create_symintnode(
- sym,
- hint=value.real, # type: ignore[attr-defined]
- source=source,
- )
- # Bind to graph input
- sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
- re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
- type(node),
- node,
- source=source,
- )
- sym_node_proxy.node.meta["grapharg"] = GraphArg(
- source,
- node,
- False,
- None,
- is_tensor=False,
- example_strong_ref=node,
- )
- sym_expr = node.node.expr
- assert isinstance(sym_expr, sympy.Symbol), (
- f"{sym_expr} is not a basic Symbol."
- )
- self.tx.output.tracked_fakes.append(TrackedFake(node, source, None))
- return SymNodeVariable.create(self.tx, sym_node_proxy, node)
- elif is_torch_sym(value):
- # Note: this doesn't handle nested symints.
- # For SymBool input, we reuse the infra for SymInt by simulating SymBool with a SymInt in dynamo.
- # Concretely,
- # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source).
- # so that guards on the SymInts can be effectively applied on the original SymBool in user program.
- # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program
- # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly.
- source = (
- self.source
- if isinstance(value, torch.SymInt)
- else ConvertIntSource(self.source)
- )
- new_symint = None
- if value.node.has_hint():
- new_symint = (
- self.tx.output.shape_env.create_unspecified_symint_and_symbol(
- int(value.node.hint),
- source,
- dynamic_dim=DimDynamic.DYNAMIC,
- )
- )
- else:
- if isinstance(value, torch.SymBool):
- # We need to create an unbacked symint to replace the unbacked symbool.
- new_symint = self.tx.output.shape_env.create_unbacked_symint()
- else:
- # TODO (yidi): we need to figure out a way to propagate the guards
- # we accumulated when tracing the subggraph to outer shape_env. For normal symints,
- # this is automatically done by evaluating the guards once but this
- # will cause data-dependent error when we evaluate the outer unbacked symints.
- # The test case that triggers this graph break is test_cond_unbacked_symint_closure
- unimplemented(
- gb_type="Attempted to wrap unbacked SymInt",
- context="",
- explanation="Unbacked SymInt input is not supported yet.",
- hints=[*graph_break_hints.SUPPORTABLE],
- )
- assert new_symint is not None
- sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
- re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
- type(new_symint),
- new_symint,
- source=source,
- )
- sym_node_proxy.node.meta["grapharg"] = GraphArg(
- source,
- new_symint,
- False,
- None,
- is_tensor=False,
- example_strong_ref=new_symint,
- )
- # We bind the new_symint to graph input.
- sym_expr = new_symint.node.expr
- assert isinstance(sym_expr, sympy.Symbol), (
- f"{sym_expr} is not a basic Symbol."
- )
- self.tx.output.tracked_fakes.append(TrackedFake(new_symint, source, None))
- tracing_symint = (
- new_symint if isinstance(value, torch.SymInt) else new_symint == 1
- ) # cast it back to symbool for tracing
- return SymNodeVariable(sym_node_proxy, tracing_symint)
- elif isinstance(value, (JITFunction, Autotuner)):
- self.install_guards(GuardBuilder.ID_MATCH)
- return TritonKernelVariable(
- value,
- None, # No kernel idx provided
- None, # No grid provided
- source=self.source,
- )
- elif value is create_1d_tma_descriptor:
- return CreateTMADescriptorExperimentalVariable(rank=1)
- elif value is create_2d_tma_descriptor:
- return CreateTMADescriptorExperimentalVariable(rank=2)
- elif value is TensorDescriptor.from_tensor:
- return CreateTMADescriptorStableVariable()
- elif value is set_allocator:
- return TritonSetAllocatorSkipVariable(value)
- elif isinstance(value, torch.amp.autocast_mode.autocast):
- self.install_guards(GuardBuilder.ID_MATCH)
- return AutocastModeVariable(
- target_values=[
- value.device,
- value.fast_dtype,
- value._enabled,
- value._cache_enabled,
- ],
- source=self.source,
- )
- elif TorchCtxManagerClassVariable.is_matching_cls(value):
- if inspect.isclass(value):
- self.install_guards(GuardBuilder.CLASS_MATCH)
- elif inspect.isfunction(value):
- self.install_guards(GuardBuilder.CLOSURE_MATCH)
- return TorchCtxManagerClassVariable(value, source=self.source)
- elif inspect.getattr_static(value, "__script_if_tracing_wrapper", False):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return WrapperUserFunctionVariable(
- value, "__original_fn", source=self.source
- )
- elif is_lru_cache_wrapped_function(value):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source)
- elif value is sys.exc_info or (
- sys.version_info >= (3, 11) and value is sys.exception
- ):
- return SysFunctionVariable(value, source=self.source)
- elif is_function_or_wrapper(value) and inspect.getattr_static(
- value, "_torchdynamo_inline", False
- ):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return WrapperUserFunctionVariable(
- value, "_torchdynamo_inline", source=self.source
- )
- elif value is functools.wraps:
- self.install_guards(GuardBuilder.ID_MATCH)
- return FunctoolsWrapsVariable(value, source=self.source)
- elif value is collections.namedtuple:
- self.install_guards(GuardBuilder.ID_MATCH)
- return CollectionsNamedTupleFunction(value, source=self.source)
- elif isinstance(
- value, types.BuiltinMethodType
- ) and BuiltinMethodVariable.is_supported_builtin_method(value):
- self.install_guards(GuardBuilder.ID_MATCH)
- return BuiltinMethodVariable(value, source=self.source)
- elif is_function(value) and value in (float.fromhex, float.hex):
- self.install_guards(GuardBuilder.ID_MATCH)
- return GetAttrVariable(
- BuiltinVariable(float, source=self.source),
- value.__name__,
- )
- elif is_function_or_wrapper(value):
- value, attr_name = unwrap_with_attr_name_if_wrapper(value)
- # For these wrappers, Dynamo points to the wrapped function,
- # so source needs to be updated as well.
- if attr_name is not None:
- self.source = AttrSource(self.source, attr_name)
- # type: ignore[attr-defined]
- return trace_rules.lookup(value).create_with_source(
- value, source=self.source
- )
- elif value is random.Random:
- self.install_guards(GuardBuilder.ID_MATCH)
- return RandomClassVariable(source=self.source)
- elif istype(value, random.Random) and RandomVariable.is_supported_random_obj(
- value
- ):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- result = RandomVariable(value, source=self.source)
- self.tx.output.side_effects.track_mutable(value, result)
- return result
- # Don't use istype, since some python modules are not subclasses of types.ModuleType directly.
- # E.g, type(torch.ops) -> <class 'torch._ops._Ops'>,
- # type(torch.backends.cudnn) -> <class 'torch.backends.cudnn.CudnnModule'>
- elif isinstance(value, (types.ModuleType, replay_record.DummyModule)):
- self.install_guards(GuardBuilder.MODULE_MATCH)
- result = PythonModuleVariable(
- # type: ignore[arg-type]
- value,
- source=self.source,
- )
- self.tx.output.side_effects.track_object_existing(value, result)
- return result
- elif isinstance(value, types.MethodType) and isinstance(
- value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec)
- ):
- # don't let MethodTypes fall through to UserDefinedObject,
- # which doesn't support 'CALL_FUNCTION'
- # TODO(whc): Why do we limit this to methods on NNModules?
- # I don't have a good reason for this, but it preserves the existing behavior
- # for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise.
- # I suspect we probably want to relax this check and dig deeper there.
- # In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python,
- # but need to separately wrap its underlying `__func__` and its `self` argument. We wrap `self` here
- # and then `__func__` gets wrapped inside UserMethodVariable.
- self_obj = VariableBuilder(
- self.tx, source=AttrSource(self.source, "__self__")
- )(value.__self__)
- assert self_obj and isinstance(self_obj, VariableTracker), (
- "Failed to produce a valid self obj"
- )
- return UserMethodVariable(
- value.__func__,
- self_obj,
- source=self.source,
- )
- elif isinstance(value, types.GetSetDescriptorType):
- # GetSet descriptors are C functions attached to an attribute lookup
- # using PyGetSetDef. Python, on attribute lookup, can decide to
- # create a new object on the fly, and therefore the `id` of the
- # descriptors is not guaranteed to be same for different attribute
- # accesses. Since these are unlikely to change during the program
- # execution, we can skip guarding on them.
- return GetSetDescriptorVariable(value)
- elif isinstance(value, types.MethodWrapperType):
- # Method-wrappers are written in C, and they are not guaranteed to
- # return the same object on attribute lookup. Therefore, we cannot
- # insert a ID_MATCH guard here. method-wrappers are very
- # unlikely to change, so its ok to skip the guard here.
- return MethodWrapperVariable(value)
- elif issubclass(type(value), type) and issubclass(value, BaseException):
- # match user defined exceptions
- self.install_guards(GuardBuilder.ID_MATCH)
- return UserDefinedExceptionClassVariable(value)
- elif issubclass(type(value), type):
- if value in (
- torch.utils.hooks.BackwardHook,
- torch.nn.Parameter,
- torch.nn.Buffer,
- ):
- # TODO(jansel): combine this case with the one above
- # type: ignore[attr-defined]
- return trace_rules.lookup(value).create_with_source(
- value, source=self.source
- )
- if value is torch.autograd._unsafe_preserve_version_counter:
- self.install_guards(GuardBuilder.CLASS_MATCH)
- return PreserveVersionContextVariable.constructor(self.tx)
- if (
- # `value` must be a strict subclass of `torch.Tensor`
- issubclass(value, torch.Tensor)
- and value is not torch.Tensor
- # `TensorSubclassVariable` is not for subclass that overrides
- # `torch_dispatch`.
- and value.__torch_dispatch__ is torch.Tensor.__torch_dispatch__
- # `TensorSubclassVariable` would lead to construction of
- # `TensorWithTFOverrideVariable`, but we don't want that for
- # traceable wrapper subclasses (we wrap those subclass instances
- # into `TensorVariable`).
- and not is_traceable_wrapper_subclass_type(value)
- ):
- return TensorSubclassVariable(value, source=self.source)
- if not is_from_closure_source(self.source):
- # For closure source, the variable comes from LOAD_SUPER_ATTR,
- # which calls self.__class__. This is internal Cpython
- # implementation, and it is rare for the user to modify
- # self.__class__ manually.
- # For other cases, this is a userdefined class, so install an
- # ID_MATCH even if its a global variable.
- self.install_guards(GuardBuilder.CLASS_MATCH)
- if is_opaque_type(value):
- return OpaqueObjectClassVariable(
- value,
- source=self.source,
- )
- if isinstance(value, type) and issubclass(value, enum.Enum):
- return UserDefinedEnumClassVariable(
- value,
- source=self.source,
- )
- return UserDefinedClassVariable(
- value,
- source=self.source,
- )
- elif TorchScriptObjectVariable.is_matching_cls(type(value)):
- from ..source import (
- FlattenScriptObjectSource,
- ScriptObjectQualifiedNameSource,
- )
- # type: ignore[arg-type]
- if torch._library.fake_class_registry.tracing_with_real(value):
- proxy = self.tx.output.root_tracer.create_graph_input(
- re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
- type(value),
- value,
- source=self.source,
- )
- # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default
- # setting example to be real value because these example values will be used
- # as example_inputs for user compiler.
- proxy.node.meta["grapharg"] = GraphArg(
- self.source,
- value,
- False,
- None,
- False,
- value, # type: ignore[arg-type]
- )
- return TorchScriptObjectVariable.create(
- proxy,
- value,
- source=self.source,
- )
- if is_opaque_value_type(type(value)):
- # Value-type: guard on equality (will use __eq__)
- self.install_guards(GuardBuilder.CONSTANT_MATCH)
- elif is_opaque_reference_type(type(value)):
- # Reference-type: guard only on type, and registered guard_fn
- self.install_guards(GuardBuilder.TYPE_MATCH)
- self.install_guards(GuardBuilder.OPAQUE_OBJ_GUARD_FN_MATCH)
- elif not hasattr(value, "__obj_flatten__"):
- # This exists to allow a smoother transition.
- # The implications are:
- # The script objects won't be tracked as proxies.
- # Methods on these objects won't show up in the graph.
- # The original script object might be mutated.
- return self.wrap_user_defined(value)
- else:
- # Install the guards on the fully qualified name of the script object
- LazyVariableTracker.realize_all(
- VariableBuilder(
- self.tx, ScriptObjectQualifiedNameSource(self.source)
- )(
- value._type().qualified_name() # type: ignore[attr-defined]
- )
- )
- # Install the guards on the content of the script object by setting the source
- # to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
- LazyVariableTracker.realize_all(
- VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
- value.__obj_flatten__()
- )
- )
- fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj(
- self.tx.output.fake_mode, value
- )
- if is_opaque_value_type(type(value)) and not should_hoist(type(value)):
- proxy = value
- else:
- proxy = self.tx.output.root_tracer.create_graph_input(
- re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
- type(value),
- fake_script_obj,
- source=self.source,
- )
- # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default
- # setting example to be real value because these example values will be used
- # as example_inputs for user compiler.
- proxy.node.meta["grapharg"] = GraphArg(
- self.source,
- value, # type: ignore[arg-type]
- False,
- None,
- False,
- fake_script_obj, # type: ignore[arg-type]
- )
- return TorchScriptObjectVariable.create(
- proxy, # pyrefly: ignore[bad-argument-type]
- fake_script_obj,
- source=self.source,
- )
- elif (
- isinstance(value, (dict, collections.OrderedDict))
- and type(value).__new__ is dict.__new__
- ):
- # Construct a dict_vt that will reside inside the UserDefinedDictVariable
- self.install_guards(GuardBuilder.TYPE_MATCH)
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- # Guard on the key order
- self.tx.output.guard_on_key_order.add(self.source)
- # We need all the keys to be hashable. We do this within the
- # _HashableTracker class in dicts.py
- def build_key_value(
- i: Any, k: Any, v: Any
- ) -> tuple[VariableTracker, VariableTracker]:
- base = self.get_source()
- source_key = ConstDictKeySource(base, i)
- key = LazyVariableTracker.create(k, source_key)
- source_value = DictSubclassGetItemSource(base, source_key)
- res_value = LazyVariableTracker.create(v, source_value)
- return key, res_value
- # Ensure that we call dict.keys and not value.keys (which can call
- # overridden keys method). In the C++ guards, we relied on
- # PyDict_Next to traverse the dictionary, which uses the internal
- # data structure and does not call the overridden keys method.
- result = dict(
- build_key_value(i, k, v)
- for i, (k, v) in enumerate(get_items_from_dict(value))
- )
- dict_vt = ConstDictVariable(
- # pyrefly: ignore[bad-argument-type]
- result,
- user_cls=(
- collections.OrderedDict
- if isinstance(value, collections.OrderedDict)
- else dict
- ),
- mutation_type=ValueMutationExisting(),
- source=self.source,
- )
- # Force this to reconstruct on mutation to keep the reconstruction
- # bytecode simple
- dict_vt.should_reconstruct_all = True
- result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source)
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif isinstance(value, tuple):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- # NB - Be careful in not triggering user code. Guards also work on
- # the underlying tuple data structure.
- output = [
- LazyVariableTracker.create(
- tuple.__getitem__(value, i),
- source=GetItemSource(self.get_source(), i),
- )
- for i in range(tuple.__len__(value))
- ]
- tuple_vt = TupleVariable(
- output, # type: ignore[arg-type]
- source=self.source,
- mutation_type=ValueMutationExisting(),
- )
- result = UserDefinedTupleVariable(
- value, tuple_vt=tuple_vt, source=self.source
- )
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif isinstance(value, list):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- # NB - Be careful in not triggering user code. Guards also work on
- # the underlying list data structure.
- output = [
- LazyVariableTracker.create(
- list.__getitem__(value, i),
- source=ListGetItemSource(self.get_source(), i),
- )
- for i in range(list.__len__(value))
- ]
- list_vt = ListVariable(
- output, # type: ignore[arg-type]
- source=self.source,
- mutation_type=ValueMutationExisting(),
- )
- result = UserDefinedListVariable(value, list_vt=list_vt, source=self.source)
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif isinstance(value, (set, frozenset)):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- L = list(dict.fromkeys(value))
- output = [
- LazyVariableTracker.create(
- list.__getitem__(L, i),
- source=NonSerializableSetGetItemSource(self.get_source(), i),
- )
- for i in range(list.__len__(L))
- ]
- if isinstance(value, set):
- set_vt_cls = SetVariable
- else:
- assert isinstance(value, frozenset)
- set_vt_cls = FrozensetVariable
- set_vt = set_vt_cls(
- output, source=self.source, mutation_type=ValueMutationExisting()
- )
- result = UserDefinedSetVariable(value, set_vt=set_vt, source=self.source)
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif issubclass(type(value), MutableMapping):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- result = MutableMappingVariable(value, source=self.source)
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif is_frozen_dataclass(value):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- result = FrozenDataClassVariable.create(self.tx, value, source=self.source)
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif isinstance(value, dict_keys):
- if all(ConstantVariable.is_literal(k) for k in value):
- # If the dict_keys object is passed from outside the compile region, it must either be passed along with
- # the corresponding dict object or treated as a set (when only the keys are passed into the compiled region).
- # - If it is passed along with the dict, the dict object itself is already guarded.
- # - If only the dict_keys object is passed, we add EQUALS_MATCH and SEQUENCE_LENGTH guards
- # to ensure it remains unchanged across multiple runs.
- items = [SourcelessBuilder.create(self.tx, v) for v in value]
- install_guard(
- self.get_source().make_guard(GuardBuilder.SEQUENCE_LENGTH),
- self.get_source().make_guard(GuardBuilder.EQUALS_MATCH),
- )
- return DictKeySetVariable(items, source=self.source)
- else:
- unimplemented(
- gb_type="non-const keys in dict_keys",
- context=f"non-const keys: {[k for k in value if not ConstantVariable.is_literal(k)]}",
- explanation="Dynamo expects dict_keys keys to be constants.",
- hints=[
- "Ensure your dict_keys keys are constants (e.g. int, float, strings)",
- ],
- )
- elif IntWrapperVariable.is_matching_object(value):
- from torch.export.dynamic_shapes import _DimHintType
- if value.dynamism is None or value.dynamism.type == _DimHintType.STATIC:
- return self.wrap_symint(value.val)
- elif value.dynamism.type == _DimHintType.DYNAMIC:
- log.debug(
- "%s marked %s via IntWrapper",
- self.source.name,
- DimDynamic.DYNAMIC,
- )
- return self.wrap_symint(
- value.val,
- dynamism=DimDynamic.DYNAMIC,
- context=SymIntSymbolicContext(
- constraint=RelaxedUnspecConstraint(warn_only=False)
- ),
- )
- elif value.dynamism.type == _DimHintType.AUTO:
- log.debug(
- "%s marked %s via IntWrapper",
- self.source.name,
- DimDynamic.DYNAMIC,
- )
- return self.wrap_symint(value.val, dynamism=DimDynamic.DYNAMIC)
- else:
- raise RuntimeError(f"Undefined dynamism {value.dynamism}")
- elif istype(value, object):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return ObjectVariable(value, source=self.source)
- else:
- return self.wrap_user_defined(value)
def wrap_user_defined(self, value: Any) -> VariableTracker:
    """Wrap an arbitrary user-defined object as a VariableTracker.

    A TYPE_MATCH guard is always installed on ``self.source``. Objects that
    ``InspectVariable.is_matching_object`` accepts become an
    ``InspectVariable`` whose source is wrapped in ``SkipGuardSource``; every
    other object becomes a ``UserDefinedObjectVariable``. The result is
    registered with side-effect tracking only when the object's class
    supports mutation side effects.
    """
    self.install_guards(GuardBuilder.TYPE_MATCH)

    if not InspectVariable.is_matching_object(value):
        wrapped = UserDefinedObjectVariable(value, source=self.source)
    else:
        # inspect-related trackers skip their guards: they rarely matter for
        # recompilation (some other guard will change too) and would
        # otherwise produce many OBJECT_ALIASING guards.
        wrapped = InspectVariable(value, source=SkipGuardSource(self.source))

    if SideEffects.cls_supports_mutation_side_effects(type(value)):
        return self.tx.output.side_effects.track_object_existing(value, wrapped)
    # Classes with a custom __setattr__ must not get STORE_ATTR mutation
    # tracking, so the tracker is returned untracked.
    return wrapped
def wrap_listlike(
    self, value: Union[tuple[Any, ...], list[Any], odict_values, NamedTuple]
) -> VariableTracker:
    """Build a VariableTracker for a list/tuple-like Python value.

    Behaviour, in order:
    - With ``config.specialize_int`` set, an exact ``torch.Size`` is guarded
      with CONSTANT_MATCH and returned as a ConstantVariable.
    - Otherwise a SEQUENCE_LENGTH guard is installed. An exact tuple of
      literals whose source is an unspecialized nn.Module attribute is
      additionally CONSTANT_MATCH-guarded and returned as a TupleVariable of
      constants.
    - Remaining items are wrapped lazily, each with a ``GetItemSource``.
    - If the value is a local named by ``get_locals_to_steal`` (the
      compiled-autograd input list), the whole list is lifted as ONE graph
      input and unpacked locally so activations can be freed early.
    - The result class comes from ``BaseListVariable.cls_for_instance``. Exact
      ``list``/``collections.deque`` values are tracked as mutable in
      side effects.
    """
    if config.specialize_int and type(value) is torch.Size:
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value)

    # One can index a tensor with a list/tuple. Therefore, we need to
    # have a stricter match.
    self.install_guards(GuardBuilder.SEQUENCE_LENGTH)

    # Tuples are immutable objects, so we should mark its items static. This
    # avoids wrapping of tuple items as symints. This helps for nn module
    # attributes like conv2d strides, dilations.
    if (
        istype(value, tuple)
        and all(ConstantVariable.is_literal(item) for item in value)
        and self.source.guard_source.is_unspecialized_nn_module()
    ):
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return TupleVariable([ConstantVariable.create(item) for item in value])

    # Lazily-built trackers for each element; they are only realized on use
    # (or eagerly in the steal-arg branch below).
    output = [
        LazyVariableTracker.create(
            item,
            source=GetItemSource(self.get_source(), i),
        )
        for i, item in enumerate(value)
    ]

    # `self` in the frame's local scope, if any, is what
    # get_locals_to_steal inspects to find the stealable locals.
    maybe_gm = self.tx.output.local_scope.get("self")
    if isinstance(
        self.source, LocalSource
    ) and self.source.local_name in get_locals_to_steal(maybe_gm):
        # The input tensor list to dynamo from compiled autograd may contain activations
        # which are freed as they are used in inductor. Dynamo's default behavior is to
        # lift all tensors to the graph inputs, but this will cause dynamo to hold an
        # extra reference to the activation tensors and increase peak memory usage.
        # To allow freeing ASAP, we keep the list as graph argument to the dynamo output
        # graph, and unpack it locally.
        # e.g. instead of `def forward(self, L_inputs_0_, L_inputs_1_, ...):`, we have
        # `def forward(self, L_inputs_):`
        source = self.source
        assert isinstance(value, list)
        tensor_list_proxy = self.tx.output.root_tracer.create_graph_input(
            re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
            type(value),
            value,
            source=source,
        )
        # Marks the graph input as one whose list may be consumed ("stolen").
        tensor_list_proxy.node.meta["steal_arg"] = True

        list_variable = wrap_fx_proxy_cls(
            target_cls=TensorVariable,
            tx=self.tx,
            proxy=tensor_list_proxy,
            example_value=value,
            subclass_type=None,
            source=source,
        )

        # Apply relevant logic from `VariableTracker.build(value[i])`
        # (except for the `create_graph_input` stuff).
        guards = []
        for i, tensor_variable in enumerate(list_variable.items):  # type: ignore[attr-defined]
            source_i = GetItemSource(base=source, index=i, index_is_slice=False)
            # access unpacked tensor from this list instead of from a lifted arg
            self.tx.output.input_source_to_var[source_i] = tensor_variable
            tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(
                value[i]
            )
            # Each element still gets its own TENSOR_MATCH guard.
            guard = functools.partial(
                GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
            )
            guards.append(source_i.make_guard(guard))

        install_guard(*guards, skip=1)

        grapharg = GraphArg(
            source,
            value,
            pass_arg_as_tensor=False,
            fake_tensor=None,
            is_tensor=False,
        )
        tensor_list_proxy.node.meta["grapharg"] = grapharg

        # The following is very important for maintaining the "python object
        # <==> variable tracker" 1-to-1 mapping, which is mainly handled via
        # `side_effects`. Note that constructing `tensor_variable` above
        # already adds it to graph arg, but we never registered it with
        # `side_effects`. The preemptive `realize` calls here basically
        # does that registration (at the end of `self.__call__`).
        #
        # A slightly cleaner alternative is to register the
        # `tensor_variable`s above with `side_effects` directly, and just
        # return the `list_variable`, but that breaks some tensor-subclass
        # related tests like `test_inputs_aliasing_bytecode_stack_restore`,
        # because `tensor_variable` is constructed via
        # `handle_traced_output`, which doesn't really expect/handle tensor
        # subclass.
        #
        # Eventually, we expect to fix remove all of these by having Dynamo
        # auto-boxing inputs to the compiled graph, see
        # https://github.com/pytorch/pytorch/issues/153701.
        for vt in output:
            vt.realize()

    result = BaseListVariable.cls_for_instance(value)(output, source=self.source)  # type: ignore[arg-type]
    # Only exact lists/deques are tracked as mutable here.
    if istype(value, (list, collections.deque)):
        return self.tx.output.side_effects.track_mutable(value, result)
    return result
def wrap_tuple_iterator(self, value: tuple_iterator) -> VariableTracker:
    """Wrap a tuple iterator as a TupleIteratorVariable.

    Installs a TUPLE_ITERATOR_LEN guard, then eagerly builds one tracker per
    element reported by ``tuple_iterator_len``. Each element's source is a
    ``TupleIteratorGetItemSource`` at its index. The resulting iterator
    tracker is registered as mutable in side effects.
    """
    self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN)

    elements = []
    for idx in range(tuple_iterator_len(value)):
        element_source = TupleIteratorGetItemSource(self.get_source(), idx)
        builder = VariableBuilder(self.tx, element_source)
        elements.append(builder(tuple_iterator_getitem(value, idx)))

    iterator_vt = TupleIteratorVariable(elements, source=self.source)
    return self.tx.output.side_effects.track_mutable(value, iterator_vt)
def wrap_range_iterator(self, value: range_iterator) -> VariableTracker:
    """Wrap a range iterator as a ListIteratorVariable of constants.

    RANGE_ITERATOR_MATCH already pins down the produced items, so the items
    themselves need no individual guards.
    """
    self.install_guards(GuardBuilder.RANGE_ITERATOR_MATCH)

    # Materialize from a deep copy so the caller's iterator is not advanced.
    snapshot = copy.deepcopy(value)
    constants = list(map(ConstantVariable.create, snapshot))

    iterator_vt = ListIteratorVariable(constants, source=self.source)
    return self.tx.output.side_effects.track_mutable(value, iterator_vt)
def wrap_slice_range(self, value: slice | range) -> SliceVariable | RangeVariable:
    """Wrap a ``slice`` or ``range`` object.

    Its ``start``/``stop``/``step`` attributes are each wrapped through a
    VariableBuilder with an ``AttrSource`` on this value's source. A
    TYPE_MATCH guard is installed afterwards.
    """
    bounds = []
    for attr in ("start", "stop", "step"):
        attr_builder = VariableBuilder(self.tx, AttrSource(self.get_source(), attr))
        bounds.append(attr_builder(getattr(value, attr)))

    self.install_guards(GuardBuilder.TYPE_MATCH)

    if not isinstance(value, slice):
        return RangeVariable(bounds, source=self.source)
    return SliceVariable(bounds, self.tx, source=self.source)
def mark_static_input(self, value: torch.Tensor, guard: bool) -> None:
    """Mark ``value`` as a static-address input for this compile.

    Args:
        value: the tensor to mark.
        guard: forwarded unchanged to ``mark_static_address``.

    If a tracker for ``value`` already exists in side effects, its graph
    node's ``tensor_dict`` metadata is updated with the tensor's
    ``_dynamo_static_input_type``. That attribute is read after
    ``mark_static_address`` runs, presumably because that call sets it.
    """
    from ..decorators import mark_static_address

    # The format string previously ended with an unmatched ")".
    static_inputs_log.debug(
        "Marking static input %s, id: %s", self.source.name, id(value)
    )
    mark_static_address(value, guard=guard)

    # Check if we've seen this tensor before and update graph metadata if needed
    # As long as this runs before AOT this is sound
    if value in self.tx.output.side_effects:
        var = self.tx.output.side_effects[value]
        var.proxy.node.meta["tensor_dict"]["_dynamo_static_input_type"] = (  # type: ignore[attr-defined]
            value._dynamo_static_input_type  # type: ignore[attr-defined]
        )
def wrap_module(self, value: torch.nn.Module) -> VariableTracker:
    """Build a VariableTracker for an ``nn.Module`` instance.

    Dispatch, in order:
    - uninitialized modules (empty ``__dict__``) graph-break;
    - an exact ``OptimizedModule`` becomes a delayed graph break if its
      ``forward`` carries ``_torchdynamo_disable``; otherwise it is unwrapped
      to ``_orig_mod`` (source extended with ``_orig_mod``) and wrapped again;
    - RNN/GRU/LSTM graph-break unless ``config.allow_rnn`` is set;
    - modules with ``_is_fsdp_managed_module`` become
      ``FSDPManagedNNModuleVariable`` (graph-break unless
      ``_fsdp_use_orig_params``);
    - modules for which ``mutation_guard.is_dynamic_nn_module`` is true become
      unspecialized module trackers (the "builtin" variant for
      ``torch.nn.modules``/``torch.ao.`` modules other than containers, or
      classes with ``_dynamo_marked_static``);
    - ``DistributedDataParallel`` subclasses become
      ``UnspecializedNNModuleVariable``;
    - anything else goes through ``register_attr_or_module``.
    """
    from ..eval_frame import OptimizedModule

    if len(value.__dict__) == 0:
        unimplemented(
            gb_type="Uninitialized nn.Module",
            context=typestr(value),
            explanation=f"Attempted to trace an uninitialized nn.Module of type {typestr(value)}.",
            hints=[
                *graph_break_hints.USER_ERROR,
                "Ensure your nn.Module instance has called `super().__init__()`.",
            ],
        )
    if istype(value, OptimizedModule):
        # Check if the optimized module was disabled
        if inspect.getattr_static(value.forward, "_torchdynamo_disable", False):
            # This bytecode is mostly of kind LOAD_ATTR or LOAD_METHOD. If
            # we graph break here, Dynamo does not know how to create
            # continuation functions for such bytecodes. So, we delay the
            # graph break to CALL_FUNCTION.
            msg = inspect.getattr_static(
                value.forward, "_torchdynamo_disable_msg", None
            )
            return DelayGraphBreakVariable(
                source=self.source,
                msg=f"Optimized `nn.Module` is wrapped with `torch.compiler.disable` (reason: {msg})",
            )

        self.install_guards(GuardBuilder.TYPE_MATCH)
        # Trace the wrapped original module instead. The source is extended
        # with `_orig_mod` so that guards installed by the recursive call
        # are relative to the inner module.
        self.source = AttrSource(self.source, "_orig_mod")
        return self.wrap_module(value._orig_mod)

    if (
        isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
        and not config.allow_rnn
    ):
        unimplemented(
            gb_type="Attempted to wrap RNN, GRU, or LSTM",
            context=str(value),
            explanation="Dynamo does not support RNN, GRU, or LSTM.",
            hints=[
                "Set torch._dynamo.config.allow_rnn=True to enable experimental support for RNN, GRU, and LSTM in Dynamo",
                *graph_break_hints.SUPPORTABLE,
            ],
        )

    if getattr(value, "_is_fsdp_managed_module", False):
        # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
        # in fully_sharded_data_parallel.py for more information

        # we can't do this assert inside FSDP constructor,
        # since we don't know yet whether dynamo will be used
        if not getattr(value, "_fsdp_use_orig_params", False):
            unimplemented(
                gb_type="FSDP with use_orig_params=False",
                context="",
                explanation="Dynamo only supports FSDP with use_orig_params=True",
                hints=[],
            )

        # Note on FSDP guarding
        # Eager FSDP already assumes (requires, but without enforcement)
        # that users don't mutate their model parameters/structure after
        # FSDP wrapping, because FSDP wouldn't notice or update its
        # FlatParams.
        #
        # Therefore, torch.compile can skip guarding on params or submodule
        # structure of fsdp_managed modules, by using FSDPNNModuleSource as
        # the guard source. This behavior is gated on
        # config.skip_fsdp_guards.
        self.install_guards(GuardBuilder.TYPE_MATCH)
        result = FSDPManagedNNModuleVariable(value, source=self.get_source())
        if not SideEffects.cls_supports_mutation_side_effects(type(value)):
            # don't allow STORE_ATTR mutation with custom __setattr__
            return result
        return self.tx.output.side_effects.track_object_existing(value, result)
    elif mutation_guard.is_dynamic_nn_module(value, self.tx.export):
        # created dynamically, don't specialize on it

        # Note [Tracing a torch.compiled function]
        # When make_fx traces a compiled function, the module seen here may
        # be an `_AttrProxy`; unwrap it to its base module and record the
        # indirection with AttrProxySource.
        # NOTE(review): the original note was truncated ("we need") —
        # confirm the full rationale with the make_fx/_AttrProxy owners.
        if isinstance(value, torch.fx.experimental.proxy_tensor._AttrProxy):
            value = value.get_base()  # type: ignore[attr-defined]
            self.source = AttrProxySource(self.source)

        if torch._dynamo.config.inline_inbuilt_nn_modules:
            freezing = is_parameter_freezing()

            # Guard against the case where user may overwrite named parameters
            # / named buffers
            # NOTE: This is not likely to happen but worth guarding to avoid
            # exception
            if (
                callable(value.named_parameters)
                and value.named_parameters.__func__  # type: ignore[attr-defined]
                is og_module_named_parameters_fn_ptr
            ):
                try:  # catch TypeErrors in named_parameters() from unserializable nn modules
                    for _, p in value.named_parameters():  # type: ignore[attr-defined]
                        self.mark_static_input(p, guard=freezing)
                except TypeError as e:
                    raise_observed_exception(type(e), self.tx, args=list(e.args))

            if (
                callable(value.named_buffers)
                and value.named_buffers.__func__ is og_module_named_buffers_fn_ptr  # type: ignore[attr-defined]
            ):
                try:  # catch TypeErrors in named_buffers() from unserializable nn modules
                    for _, b in value.named_buffers():  # type: ignore[attr-defined]
                        self.mark_static_input(b, guard=freezing)
                except TypeError as e:
                    raise_observed_exception(type(e), self.tx, args=list(e.args))

            if freezing:
                # we need to add the module to tracing context
                # in order to allow its params to get invalidated
                # this will get cleaned up once compile ends
                self.tx.output.nn_modules[self.name] = value

        # Built-in torch.nn / torch.ao modules (excluding containers), or
        # classes explicitly marked static, use the "builtin" unspecialized
        # tracker; all other dynamic modules use the generic one.
        if (
            value.__module__.startswith(("torch.nn.modules", "torch.ao."))
            and not value.__module__.startswith("torch.nn.modules.container")
        ) or getattr(value.__class__, "_dynamo_marked_static", False):
            new_source = self.source
            if config.inline_inbuilt_nn_modules and (
                not self.tx.output.export or config.install_free_tensors
            ):
                # Export corner case - look at test_repros.py test_inlining_cornercase
                new_source = UnspecializedBuiltinNNModuleSource(self.source)
            result = UnspecializedBuiltinNNModuleVariable(value, source=new_source)
            install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
        else:
            new_source = self.source
            if config.inline_inbuilt_nn_modules and (
                not self.tx.output.export or config.install_free_tensors
            ):
                # Export corner case - look at test_repros.py test_inlining_cornercase
                new_source = UnspecializedNNModuleSource(self.source)
            result = UnspecializedNNModuleVariable(value, source=new_source)
            install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))

        self.tx.output.add_fqn_info_for_inlined_modules(value, self.source)

        if not SideEffects.cls_supports_mutation_side_effects(type(value)):
            # don't allow STORE_ATTR mutation with custom __setattr__
            return result
        return self.tx.output.side_effects.track_object_existing(value, result)
    elif issubclass(
        value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
    ):
        self.install_guards(GuardBuilder.TYPE_MATCH)
        return UnspecializedNNModuleVariable(value, source=self.get_source())
    else:
        return self.tx.output.register_attr_or_module(
            value,
            self.name,
            source=self.get_source(),
            # Guards are added inside register_attr_or_module
        )
def wrap_literal(self, value: object) -> VariableTracker:
    """Build a VariableTracker for a Python literal.

    - ``bool``/``str``: wrapped via the lazy-constant path.
    - ``float``: lazy constant; when ``config.specialize_float`` is off the
      eager fallback wraps it as a SymFloat.
    - ``int``: source allowlists (dynamic / unbacked) take precedence. Then,
      when ``config.specialize_int`` is off, ints hitting
      ``is_int_specialization_case`` are EQUALS_MATCH-guarded constants (with a
      recompile hint for nn.Module attributes), and other ints use the
      lazy-constant path with a SymInt fallback. With ``specialize_int`` on,
      ints use the plain lazy-constant path.
    - anything else: CONSTANT_MATCH-guarded ConstantVariable, tracked as
      mutable when it is a ``list`` or ``set``.
    """
    literal_type = type(value)

    if literal_type in (bool, str):
        assert isinstance(value, (bool, str))
        return self._wrap_lazy_constant(value)

    if literal_type is float:
        assert isinstance(value, float)
        if config.specialize_float:
            return self._wrap_lazy_constant(value)
        return self._wrap_lazy_constant(value, self._wrap_symfloat_for_lazy)

    if literal_type is not int:
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        const_vt = ConstantVariable.create(value=value, source=self.source)
        if isinstance(value, (list, set)):
            return self.tx.output.side_effects.track_mutable(value, const_vt)
        return const_vt

    assert isinstance(value, int)
    source_name = self.source.name

    # Source allowlists win over any specialization configuration.
    if is_dynamic_source(source_name):
        log.debug("%s marked dynamic via source whitelist", source_name)
        return self.wrap_symint(value, dynamism=DimDynamic.DYNAMIC)
    if is_unbacked_source(source_name):
        log.debug("%s marked unbacked via source whitelist", source_name)
        return self.wrap_symint(value, dynamism=DimDynamic.UNBACKED)

    if config.specialize_int:
        return self._wrap_lazy_constant(value)

    # Ints are unspecialized by default, except in the cases below.
    if not is_int_specialization_case(value, self.source):
        return self._wrap_lazy_constant(value, self._wrap_symint_for_lazy)

    guard_source = self.source.guard_source
    hint_for_recompile = None
    if (
        guard_source.is_unspecialized_builtin_nn_module()
        or guard_source.is_unspecialized_nn_module()
    ):
        # An int attribute of an nn.Module: treated as static, but remember
        # the value so a user seeing recompiles is told how to make it
        # dynamic.
        hint_for_recompile = (
            "torch.compile considers integer attributes of the nn.Module to be static. "
            "If you are observing recompilation, you might want to make this integer dynamic "
            "using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this "
            "integer into a tensor."
        )

    process_automatic_dynamic(
        self.tx,
        source_name,
        FrameStateSizeEntry.make_scalar(value),
        is_unspecialized_nn_module=guard_source.is_unspecialized_nn_module(),
    )
    self.install_guards(
        functools.partial(GuardBuilder.EQUALS_MATCH, recompile_hint=hint_for_recompile)
    )
    return ConstantVariable.create(value=value, source=self.source)
def _wrap_symint_for_lazy(self, value: int) -> VariableTracker:
    """Eager fallback for lazily-wrapped ints: delegate to ``wrap_symint``."""
    symint_wrapper = self.wrap_symint
    return symint_wrapper(value)
def _wrap_symfloat_for_lazy(self, value: float) -> VariableTracker:
    """Eager fallback for lazily-wrapped floats: delegate to ``wrap_symfloat``."""
    symfloat_wrapper = self.wrap_symfloat
    return symfloat_wrapper(value)
@overload
def _wrap_lazy_constant(
    self, value: int, wrap_fn: Callable[[int], VariableTracker]
) -> VariableTracker: ...


@overload
def _wrap_lazy_constant(
    self, value: float, wrap_fn: Callable[[float], VariableTracker]
) -> VariableTracker: ...


@overload
def _wrap_lazy_constant(
    self, value: Union[int, float, bool, str], wrap_fn: None = None
) -> VariableTracker: ...


def _wrap_lazy_constant(
    self,
    value: Union[int, float, bool, str],
    wrap_fn: Optional[Callable[[Any], VariableTracker]] = None,
) -> VariableTracker:
    """Wrap a primitive constant, deferring guard installation if allowed.

    When ``self.allow_lazy_constant`` is truthy, a LazyConstantVariable is
    returned and no guard is installed here. Otherwise the value is wrapped
    eagerly: through ``wrap_fn`` when one is given, or else as a
    ConstantVariable guarded with CONSTANT_MATCH.
    """
    if self.allow_lazy_constant:
        return LazyConstantVariable.create(value, source=self.source)

    if wrap_fn is None:
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value, source=self.source)

    return wrap_fn(value)
- def assert_not_wrapped_by_this_graph(self, value: torch.Tensor) -> None:
- if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
- raise InternalTorchDynamoError(
- "Cannot wrap a Tensor that has already been",
- "wrapped by this instance of Dynamo",
- )
def wrap_tensor(self, value: torch.Tensor) -> VariableTracker:
    """Wrap a real ``torch.Tensor`` reached through ``self.get_source()``.

    The tensor ends up in one of these forms:

    * Registered as a graph attribute via ``register_attr_or_module`` when it
      is a "free" tensor (``config.install_free_tensors`` with a global,
      nonlocal or unspecialized-nn-module source), lives on a specialized,
      non-FSDP nn module, is a static input that must become a graph
      attribute, has a ``"guarded"`` static address (``ID_MATCH`` guard), or
      comes from a constant source.
    * The previously created variable, if this exact source was already
      wrapped (``input_source_to_var``).
    * Otherwise a new root-tracer graph input with a fake example value,
      tensor guards, and a ``GraphArg`` stored on the placeholder node.

    Graph breaks (via ``unimplemented``) on strided NestedTensors, on sparse
    tensors unless exporting with ``config.capture_sparse_compute``, and on
    tensors whose grad dtype differs from the tensor dtype.

    Raises:
        InternalTorchDynamoError: if ``value`` was already fakeified by this
            Dynamo instance, or the wrapped example value is not a fake of
            ``self.tx.fake_mode``.
        RuntimeError: if a DTensor's flatten layout is not the expected
            ``_local_tensor`` / ``(_spec, requires_grad)`` shape.
    """
    source = self.get_source()

    # We cannot already be tracking the tensor, which implies
    # it would have already been wrapped
    assert value not in self.tx.output.side_effects

    is_static_input = get_static_address_type(value) is not None

    if (
        config.inline_inbuilt_nn_modules
        and not is_static_input
        and (
            isinstance(value, torch.nn.Parameter)
            # mark tensor attributes of nn modules static. This is done to keep
            # inline_inbuilt_nn_modules behavior compatible with previous behavior.
            or (source and source.guard_source.is_unspecialized_nn_module())
        )
    ):
        self.mark_static_input(value, guard=is_parameter_freezing())
        is_static_input = True

    # Install any tensors which are "free" variables; that is:
    # 1. Globals
    # 2. NonLocals
    # 3. tensors that are attributes of nn module
    should_install_free_tensor = config.install_free_tensors and (
        is_from_global_source(source)
        or is_from_nonlocal_source(source)
        or is_from_unspecialized_nn_module_source(source)
    )

    # Static inputs become graph attributes (rather than graph inputs) unless
    # inline_inbuilt_nn_modules is on without freezing / prepare_freezing.
    make_graph_attribute = is_static_input and (
        not config.inline_inbuilt_nn_modules
        or is_parameter_freezing()
        or torch._dynamo.config.prepare_freezing
    )

    if should_install_free_tensor or (
        (source.guard_source.is_specialized_nn_module() or make_graph_attribute)
        and not source.guard_source.is_fsdp_module()
    ):
        self.assert_not_wrapped_by_this_graph(value)
        return self.tx.output.register_attr_or_module(
            value, self.name, source=source
        )

    if get_static_address_type(value) == "guarded":
        # If it's a guarded tensor, we can install the parameter directly
        # into the Fx graph instead of lifting it as an input. Lifting
        # offers no benefit, such as regional compilation, since we still
        # guard on the tensor's ID. Moreover, installing it in the Fx graph
        # eliminates the pre-graph bytecode required to extract the tensor
        # from locals/globals, reducing overhead. This can lead to
        # significant cost savings, especially for optimizers handling many
        # tensors.
        self.install_guards(GuardBuilder.ID_MATCH)
        self.assert_not_wrapped_by_this_graph(value)
        return self.tx.output.register_attr_or_module(
            value, self.name, source=source
        )

    if is_constant_source(source):
        self.assert_not_wrapped_by_this_graph(value)
        return self.tx.output.register_attr_or_module(
            value,
            re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
            source=source,
            # Guards are added inside register_attr_or_module
        )

    # NB: this just says we accessed a tensor from the same source again
    # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice).
    # This is distinct from two distinct sources mapping to the same
    # Tensor (per id())! No guard is necessary here. See below for the
    # other case.
    is_duplicate_tensor = source in self.tx.output.input_source_to_var
    if is_duplicate_tensor:
        return self.tx.output.input_source_to_var[source]

    # No extra options are currently forwarded to wrap_fx_proxy below.
    options = {}
    subclass_type = infer_subclass_type(value)
    if subclass_type is not None:
        self.install_guards(GuardBuilder.TYPE_MATCH)

    # NOTE(review): the identical "guarded" check above returns early and
    # nothing in between changes the static address type, so this branch
    # looks unreachable — confirm before relying on it.
    if get_static_address_type(value) == "guarded":
        self.install_guards(GuardBuilder.ID_MATCH)

    # By this point, we should have deduplicated all tensors
    self.assert_not_wrapped_by_this_graph(value)

    if (
        isinstance(value, torch.Tensor)
        and value.is_nested
        and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor)
    ):
        unimplemented(
            gb_type="Attempted to wrap strided NestedTensor",
            context="",
            explanation="torch.compile does not support strided NestedTensor",
            hints=[],
        )

    # TODO(pearu,sparse-team) - Add the corresponding SPARSE_TENSOR_MATCH guards
    if (
        isinstance(value, torch.Tensor)
        and is_sparse_any(value)
        and (not self.tx.export or not config.capture_sparse_compute)
    ):
        # A hot fix for sparse tensors + torch.compile. Support for
        # export + sparsity is being added but we need to create
        # SPARSE_TENSOR_GUARDS for guards to work properly.
        unimplemented(
            gb_type="Attempted to wrap sparse Tensor",
            context="",
            explanation="torch.compile does not support sparse Tensors",
            hints=[*graph_break_hints.SPARSE_TENSOR],
        )

    if (
        safe_has_grad(value)
        and safe_grad(value) is not None
        # type: ignore[attr-defined]
        and value.dtype != safe_grad(value).dtype
    ):
        safe_grad_val = safe_grad(value)
        grad_str = str(safe_grad_val.dtype) if safe_grad_val is not None else "None"
        unimplemented(
            gb_type="dtype mismatch between tensor and its gradient",
            context=f"tensor dtype: {value.dtype}; grad dtype: {grad_str}",
            explanation="Inconsistent dtype between tensor and its gradient. "
            "This can happen in FSDP and crashes meta tensor creation.",
            hints=[*graph_break_hints.SUPPORTABLE],
        )

    # tx.output has multiple tracers if we're introspecting HigherOrderOperator.
    # When we've discovered an untracked tensor, then we actually need
    # to get Dynamo to track the tensor (which is what this function does)
    # and put it as a graph input on the root tracer. Later on,
    # if the input is actually used in the body of the HigherOrderOperator,
    # then the relevant SubgraphTracer will lift it to being an input of
    # the subgraph.
    # See NOTE [HigherOrderOperator tracing design] for more details.

    example_value = wrap_to_fake_tensor_and_record(
        value, tx=self.tx, is_tensor=True, source=source
    )
    tensor_proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(value),
        example_value,
        source=source,
    )
    cache_real_value_when_export(self.tx, tensor_proxy, value)

    tensor_variable = wrap_fx_proxy(
        tx=self.tx,
        proxy=tensor_proxy,
        example_value=example_value,
        subclass_type=subclass_type,
        source=source,
        **options,
    )

    if value._is_view():
        # If value is a view, add its base tensor to the tracked fakes list.
        # This is so we are able to access the correct source for its symbolic
        # shape values, in case we need them.
        wrap_to_fake_tensor_and_record(
            value._base,
            tx=self.tx,
            source=AttrSource(source, "_base"),
            is_tensor=True,
        )

    guard_type = GuardBuilder.TENSOR_MATCH

    # Gradients reached from an optimizer source only get a NOT_NONE_MATCH
    # guard instead of the full TENSOR_MATCH.
    if isinstance(source, GradSource) and is_from_optimizer_source(source):
        guard_type = GuardBuilder.NOT_NONE_MATCH

    is_dtensor = torch.distributed.is_available() and isinstance(
        value, torch.distributed.tensor.DTensor
    )
    if not is_dtensor:
        # We guard on the _local_tensor and the _spec, and therefore we dont
        # have to guard on the outer DTensor.
        # NumpyTensorSource tensors are passed to the guard directly; all
        # others are passed through a TensorWeakRef.
        self.install_guards(
            functools.partial(
                guard_type,
                value=(
                    value
                    if isinstance(source, NumpyTensorSource)
                    else TensorWeakRef(value)
                ),
            )
        )

    # We install TYPE_MATCH guards for traceable wrapper subclass object,
    # and recursively install corresponding guard for each inner attribute.
    if is_traceable_wrapper_subclass(value):
        # Tensor subclass guards are very expensive because they are
        # implemented in Python. Since DTensor is PyTorch-maintained class,
        # we can skip a lot of these guards.
        if is_dtensor:
            self.install_guards(GuardBuilder.TYPE_MATCH)

            # The inner tensor name is always _local_tensor. If its not, we
            # raise assertion to update the check accordingly.
            inner_tensor_name = value.__tensor_flatten__()[0][0]
            if inner_tensor_name != "_local_tensor":
                raise RuntimeError(
                    "Expecting Dtensor inner tensor name to be _local_tensor"
                )

            # Now selectively guard on the flattening context
            flattening_ctx = value.__tensor_flatten__()[1]
            # This is supposed to be (self._spec, self.requires_grad)
            if not (
                len(flattening_ctx) == 2
                and flattening_ctx[0] == value._spec
                and flattening_ctx[1] == value.requires_grad
            ):
                # If not, raise an assertion to update to the new guards
                raise RuntimeError(
                    "Expecting Dtensor flattening ctx to be _spec, requires_grad"
                )
            # Guard on the dtensor spec
            install_guard(
                AttrSource(self.source, "_spec").make_guard(
                    GuardBuilder.DTENSOR_SPEC_MATCH
                )
            )
            # Move this to C++
            install_guard(
                AttrSource(self.source, "requires_grad").make_guard(
                    GuardBuilder.EQUALS_MATCH
                )
            )
        else:
            self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
            self.install_guards(GuardBuilder.TYPE_MATCH)
            install_guard(
                SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
            )

        # Recursively wrap (and eagerly realize) every inner tensor attribute
        # so its guards are installed too.
        attrs, _ = value.__tensor_flatten__()
        for attr in attrs:
            inner_value = getattr(value, attr)
            inner_source = AttrSource(self.source, attr)
            LazyVariableTracker.realize_all(
                VariableBuilder(self.tx, inner_source)(inner_value)
            )

    # Remember this source so a repeated access returns the same variable.
    self.tx.output.input_source_to_var[source] = tensor_variable
    assert "tensor_dict" not in tensor_proxy.node.meta
    tensor_proxy.node.meta["tensor_dict"] = _extract_tensor_dict(value)

    # Note: this information is conveyed via subclass_type now
    # type: ignore[attr-defined]
    fake_tensor_value = tensor_variable.proxy.node.meta["example_value"]
    if maybe_get_fake_mode(fake_tensor_value) is not self.tx.fake_mode:
        raise InternalTorchDynamoError("Wrapped Tensor must be this graph's fake")

    grapharg = GraphArg(source, value, False, fake_tensor_value)
    tensor_proxy.node.meta["grapharg"] = grapharg
    return tensor_variable
def wrap_numpy_ndarray(self, value: Any) -> VariableTracker:
    """Wrap a ``numpy.ndarray`` as a tensor-backed graph input.

    The array is converted with ``torch._numpy._util._try_convert_to_tensor``.
    If the array was read-only, the result is cloned with preserved strides.
    The tensor is then run through a throwaway ``VariableBuilder`` on a
    ``NumpyTensorSource`` to install tensor guards and graph args. It is then
    lifted as a root-tracer graph input and returned as a
    ``NumpyNdarrayVariable`` whose ``source`` is reset to ``self.source``.

    Graph breaks (via ``unimplemented``) if the conversion raises
    ``NotImplementedError``.
    """
    assert np is not None
    assert isinstance(value, np.ndarray)

    source = NumpyTensorSource(self.get_source())

    from torch._numpy import _util

    readonly = not value.flags.writeable
    if readonly:
        # NOTE(review): the writeable flag is flipped on the caller's array
        # and never restored in this function — confirm that is intended.
        try:
            value.flags.writeable = True
        except ValueError:
            # One can not easily make nditer elements writable,
            # but warning is not the end of the world
            assert isinstance(value.base, np.nditer)

    tensor_value = None
    with torch_function_mode_stack_state_mgr.temp_restore_stack():
        try:
            tensor_value = _util._try_convert_to_tensor(value)
            if readonly:
                from torch._prims_common import clone_preserve_strides

                # Copy so the tensor does not share memory with an array the
                # user marked read-only.
                tensor_value = clone_preserve_strides(tensor_value)
        except NotImplementedError as e:
            # failed to convert to tensor, graph break
            unimplemented(
                gb_type="failed to convert numpy.ndarray to Tensor",
                context=str(value),
                explanation="Exception encountered when attempting to convert numpy.ndarray to Tensor",
                hints=[],
                from_exc=e,
            )
    assert tensor_value is not None

    # We do this because we want the full behavior of guarding the numpy ndarray as if it were
    # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
    # that there's not another great way to do this atm.
    # This creates the right graphargs, as well as registration for guards in tensor names and shape env.
    LazyVariableTracker.realize_all(VariableBuilder(self.tx, source)(tensor_value))
    example_value = wrap_to_fake_tensor_and_record(
        tensor_value,
        tx=self.tx,
        is_tensor=False,
        source=source,
    )
    proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(tensor_value),
        example_value,
        source=source,
    )
    cache_real_value_when_export(self.tx, proxy, tensor_value)
    options = {"source": source}
    numpy_ndarray_variable = wrap_fx_proxy_cls(
        target_cls=NumpyNdarrayVariable,
        tx=self.tx,
        proxy=proxy,
        example_value=example_value,
        subclass_type=None,
        **options,
    )

    self.tx.output.input_source_to_var[source] = numpy_ndarray_variable
    # type: ignore[attr-defined]
    example_value = numpy_ndarray_variable.proxy.node.meta["example_value"]

    # pass_arg_as_tensor should be true because we are wrapping a np.ndarray as argument input, and it needs to be
    # converted to a tensor.
    grapharg = GraphArg(
        source,
        tensor_value,
        pass_arg_as_tensor=True,
        fake_tensor=example_value,
        is_tensor=True,
        example_strong_ref=tensor_value,
    )
    proxy.node.meta["grapharg"] = grapharg

    # TODO - Why do we need to set the source of the np ndarray vt back to
    # original source. Many tests fails.
    numpy_ndarray_variable.source = self.source

    return numpy_ndarray_variable
def wrap_symint(
    self,
    value: int,
    dynamism: DimDynamic | None = None,
    context: SymIntSymbolicContext | None = None,
) -> VariableTracker:
    """Wrap a Python ``int`` as a SymInt graph input, or specialize it.

    Outcomes, in the order they are checked:

    * ``self.name`` already in ``unspec_variable_map``: return that entry.
    * ``TracingContext.get().force_unspec_int_unbacked_size_like``: create an
      unbacked, size-constrained SymInt.
    * Non-constant source: either specialize to a ``ConstantVariable`` (with
      a ``CONSTANT_MATCH`` guard) or create a backed SymInt. Its
      ``DimDynamic`` is chosen from ``dynamism``, the automatic-dynamic frame
      state, the base ``LocalSource``'s ``dynamism`` marking, or
      ``config.assume_static_by_default``.
    * Constant source: specialize to a ``ConstantVariable``.

    A created SymInt is lifted as a root-tracer graph input, bound to its
    sympy symbol, wrapped in a ``SymNodeVariable`` and cached in
    ``unspec_variable_map``.

    Args:
        value: the int to wrap; must be exactly ``int`` (``bool`` is rejected).
        dynamism: if given, used as the ``DimDynamic`` instead of the
            heuristics, and it also bypasses ``config.specialize_int``.
        context: symbolic context stored in the ``TrackedFake`` on the
            backed-SymInt path.
    """
    assert type(value) is int

    if self.name in self.tx.output.unspec_variable_map:
        return self.tx.output.unspec_variable_map[self.name]

    shape_env = self.tx.output.shape_env
    if TracingContext.get().force_unspec_int_unbacked_size_like:
        wrapped_value = shape_env.create_unbacked_symint()
        _constrain_range_for_size(wrapped_value)
        self.tx.output.tracked_fakes.append(
            TrackedFake(wrapped_value, self.source, None)
        )

    # NB: We do not do float. For motivation, see
    # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit
    # but the general idea is that we generate kernels that can
    # take unspecialized floats and use them in sizevar computation
    elif not is_constant_source(self.get_source()):
        if dynamism is None and torch._dynamo.config.specialize_int:
            # With config.specialize_int and no explicit `dynamism`, return a
            # constant (callers are expected to handle this case already,
            # TBH). An explicit `dynamism` overrides this and produces a
            # symint.
            self.install_guards(GuardBuilder.CONSTANT_MATCH)
            return ConstantVariable.create(value=value, source=self.source)

        name = self.source.name

        frame_state_entry = process_automatic_dynamic(
            self.tx,
            name,
            FrameStateSizeEntry.make_scalar(value),
            is_unspecialized_nn_module=self.source.guard_source.is_unspecialized_nn_module(),
        )

        # TODO: This should be dynamic, as we in general do not
        # know if bare integers are actually going to be sizevars
        # and it is inappropriate to eagerly duck size them with
        # real sizevars
        normalized_source_name = normalize_source_name(self.source.name)
        base_source = self.source
        if isinstance(base_source, ChainedSource):
            base_source = base_source.get_base()

        if dynamism is not None:
            dynamic_dim = dynamism
        elif (
            config.automatic_dynamic_shapes
            and frame_state_entry.scalar is auto_dynamic
        ):
            set_feature_use("dynamo.automatic_dynamic_shapes", True)
            dynamic_dim = get_automatic_dynamic_shapes_mark_as()
        elif (
            isinstance(base_source, LocalSource)
            and base_source.dynamism is not None
            # presumably dynamism maps normalized source name -> {dim: is_dynamic};
            # index 0 is the scalar's entry — TODO confirm
            # pyrefly: ignore[no-matching-overload]
            and dict(base_source.dynamism).get(normalized_source_name, {0: False})[
                0
            ]
        ) or not config.assume_static_by_default:
            dynamic_dim = DimDynamic.DYNAMIC
        else:  # assume_static_by_default
            # TODO: dynamic_dim = DimDynamic.STATIC should work but
            # for some reason it doesn't
            if frame_state_entry.scalar is auto_dynamic:
                set_feature_use("dynamo.automatic_dynamic_shapes", False)
            self.install_guards(GuardBuilder.CONSTANT_MATCH)
            # NOTE(review): unlike the other constant returns in this method,
            # this one does not pass source=self.source — confirm intended.
            return ConstantVariable.create(value=value)

        wrapped_value = shape_env.create_unspecified_symint_and_symbol(
            value,
            source=self.source,
            dynamic_dim=dynamic_dim,
        )

        self.tx.output.tracked_fakes.append(
            TrackedFake(wrapped_value, self.source, context)
        )
    else:
        assert is_constant_source(self.get_source())
        # TODO: Do I actually need guard for constant source?
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value, source=self.source)

    # Only reached when a SymInt was created above.
    assert not isinstance(self.get_source(), RandomValueSource)
    install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))

    options = {"source": self.get_source()}

    proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(wrapped_value),
        wrapped_value,
        source=self.get_source(),
    )

    sym_expr = wrapped_value.node.expr
    assert isinstance(sym_expr, sympy.Symbol), f"{sym_expr} is not a basic Symbol."
    # Map the new symbol to its placeholder proxy on the root tracer.
    self.tx.output.root_tracer.bound_symbols[sym_expr] = proxy
    unspec_var = SymNodeVariable.create(self.tx, proxy, wrapped_value, **options)
    # type: ignore[assignment]
    self.tx.output.unspec_variable_map[self.name] = unspec_var

    if not is_constant_source(self.get_source()):
        proxy.node.meta["grapharg"] = GraphArg(
            self.get_source(),
            wrapped_value,
            pass_arg_as_tensor=False,
            fake_tensor=None,
            is_tensor=False,
            example_strong_ref=wrapped_value,
        )

    return unspec_var
def wrap_symfloat(self, value: float) -> VariableTracker:
    """Wrap a Python ``float`` as an unspecialized SymFloat, or specialize it.

    The float is specialized to a ``ConstantVariable`` (with a
    ``CONSTANT_MATCH`` guard) when ``specialize_float`` is set, the source is
    constant, the value is nan/inf, cudagraphs are enabled, the killswitch
    justknob is on, or static-by-default applies without automatic dynamism.
    It is also specialized when the tensor-ified value is a grad-tracking
    tensor.

    Otherwise the value is lifted as a float64 0-d tensor graph input (on a
    ``FloatTensorSource``). The ``UnspecializedPythonVariable`` for it is
    cached in ``unspec_variable_map`` under ``self.name``. The result of
    tracing ``.item()`` on it is returned.

    NOTE(review): the cache holds the tensor-valued variable, while this
    method returns the ``.item()`` result, so a repeated call for the same
    name returns the tensor variable — confirm that is intended.

    Raises:
        AssertionError: during export, if the source is not a ``LocalSource``.
    """
    # SymFloat wrapping is special. We first wrap it in the same way we
    # do an unspecialized primitive, and then we item() it into a
    # SymFloat. Removal of the item() call is left to a later FX pass,
    # mostly because that pass is more easily done after we have lowered
    # to ATen ops. (Dynamo doesn't do decomposition right now).

    if self.name in self.tx.output.unspec_variable_map:
        return self.tx.output.unspec_variable_map[self.name]

    frame_state_entry = process_automatic_dynamic(
        self.tx,
        self.source.name,
        FrameStateSizeEntry.make_scalar(value),  # type: ignore[arg-type]
        is_unspecialized_nn_module=self.source.guard_source.is_unspecialized_nn_module(),
    )

    # NB: we specialize on nan input, because our guard modeling in
    # ShapeEnv cannot deal with nan
    if (
        torch._dynamo.config.specialize_float
        or is_constant_source(self.get_source())
        or math.isnan(value)
        or math.isinf(value)
        # We don't support cudagraphs for now. Without this cudagraphs
        # break because they expect all cuda inputs but our tensorified
        # float will be a f64[] cpu tensor. Fixes the following test
        # when specialize_float=False
        # python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda # noqa: B950
        or torch._inductor.config.triton.cudagraphs
        or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False)
        or (
            config.assume_static_by_default
            and frame_state_entry.scalar is not auto_dynamic
        )
    ):
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value, source=self.source)

    # NB: At the point we've gotten here, we don't assume static by
    # default. Since we have a guard mechanism, there isn't really any
    # downside to trying to be dynamic for float all the time. Unlike
    # ints, this won't make codegen perf worse. Modest cost to compile
    # time.

    wrapped_value = torch.tensor(value, dtype=torch.float64)

    # We don't support specializing floats for grad checking tensors
    # See https://github.com/pytorch/pytorch/pull/140828 for more
    # context.
    if torch._C._functorch.is_gradtrackingtensor(wrapped_value):
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value, source=self.source)

    # TODO: Switch RandomValueSource over to use this, this is more
    # accurate
    assert not isinstance(self.get_source(), RandomValueSource)
    install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))

    # The FloatTensorSource here is just for pedantic correctness: if you
    # guard against an UnspecializedPythonVariable, you need to guard
    # against the tensor-ified version of the local, otherwise it's not a
    # Tensor. However, we never let the UnspecializedPythonVariable escape
    # here, so there should never actually be any guards against this
    # source.
    source = FloatTensorSource(self.get_source())
    options = {"source": source, "raw_value": value}

    # TODO: Maybe the tensor-ification should be built into the source,
    # rather than by special pattern match
    example_value = wrap_to_fake_tensor_and_record(
        wrapped_value, tx=self.tx, is_tensor=False, source=source
    )

    proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(wrapped_value),
        example_value,
        source=source,
    )
    cache_real_value_when_export(self.tx, proxy, wrapped_value)

    unspec_var = wrap_fx_proxy_cls(
        UnspecializedPythonVariable,
        tx=self.tx,
        proxy=proxy,
        example_value=example_value,
        subclass_type=None,
        **options,
    )
    assert isinstance(unspec_var, UnspecializedPythonVariable)
    self.tx.output.unspec_variable_map[self.name] = unspec_var

    if self.tx.export and not isinstance(self.get_source(), LocalSource):
        raise AssertionError(
            f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
        )

    example_value = unspec_var.proxy.node.meta["example_value"]
    assert is_fake(example_value)

    fake_tensor_value = example_value
    # Both fragments are f-strings so the expected fake mode is actually
    # interpolated into the message.
    assert fake_tensor_value.fake_mode is self.tx.fake_mode, (  # type: ignore[attr-defined]
        f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode "
        f"({self.tx.fake_mode}) from InstructionTranslator"
    )

    # There's something a bit incoherent about pass_arg_as_tensor,
    # specifically regarding sources.
    #
    # Specifically, suppose we have "x: float" local argument. We
    # eventually end up with an UnspecializedPythonVariable denoting
    # torch.as_tensor(x)... but its source is still L['x'] (which if you
    # accessed it directly is a float!) So you gotta be careful when
    # setting up your guards, because it's still going to be a float at
    # this point, the conversion happens only precisely at the point we're
    # actually calling the FX graph. This happens to be what we want for
    # shape guard generation, but it's kind of unintuitive.
    proxy.node.meta["grapharg"] = GraphArg(
        self.get_source(),
        wrapped_value,
        pass_arg_as_tensor=True,
        fake_tensor=fake_tensor_value,  # type: ignore[arg-type]
        is_tensor=False,
        example_strong_ref=wrapped_value,
    )

    # Directly do item to bypass capture_scalar_outputs
    r = wrap_fx_proxy(
        self.tx,
        self.tx.output.create_proxy(
            "call_method",
            "item",
            *proxy_args_kwargs([unspec_var], {}),
        ),
    )
    self.tx.output.tracked_fakes.append(TrackedFake(r.sym_num, self.source, None))  # type: ignore[attr-defined]

    get_metrics_context().set("tensorify_float_attempt", True, overwrite=True)

    return r
def wrap_unspecialized_primitive(self, value: Any) -> VariableTracker:
    """Lift an unspecialized Python primitive into the graph as a tensor input.

    ``value`` is converted with ``torch.tensor`` and registered as a
    root-tracer graph input. The resulting ``UnspecializedPythonVariable`` is
    cached in ``unspec_variable_map`` under ``self.name``, and a repeated call
    with the same name returns the cached variable. For non-constant sources
    a ``GraphArg`` is attached to the placeholder node.

    Raises:
        AssertionError: during export, if a non-constant source is not a
            ``LocalSource``, or if the example value is not a fake of
            ``self.tx.fake_mode``.
    """
    if self.name in self.tx.output.unspec_variable_map:
        return self.tx.output.unspec_variable_map[self.name]

    wrapped_value = torch.tensor(value)
    if not isinstance(self.get_source(), RandomValueSource):
        install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))

    options = {"source": self.get_source(), "raw_value": value}

    example_value = wrap_to_fake_tensor_and_record(
        wrapped_value, tx=self.tx, is_tensor=False, source=self.get_source()
    )
    proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(wrapped_value),
        example_value,
        source=self.get_source(),
    )
    cache_real_value_when_export(self.tx, proxy, wrapped_value)

    unspec_var = wrap_fx_proxy_cls(
        UnspecializedPythonVariable,
        tx=self.tx,
        proxy=proxy,
        example_value=example_value,
        subclass_type=None,
        **options,
    )
    self.tx.output.unspec_variable_map[self.name] = unspec_var  # type: ignore[assignment]

    if not is_constant_source(self.get_source()):
        if self.tx.export and not isinstance(self.get_source(), LocalSource):
            raise AssertionError(
                f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
            )

        if unspec_var.is_python_constant():
            # TODO: when can this happen?
            example_value = unspec_var.as_python_constant()
        else:
            example_value = unspec_var.proxy.node.meta["example_value"]  # type: ignore[attr-defined]
        assert is_fake(example_value)

        fake_tensor_value = example_value
        # Both fragments are f-strings so the expected fake mode is actually
        # interpolated into the message.
        assert fake_tensor_value.fake_mode is self.tx.fake_mode, (  # type: ignore[attr-defined]
            f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode "
            f"({self.tx.fake_mode}) from InstructionTranslator"
        )

        proxy.node.meta["grapharg"] = GraphArg(
            self.get_source(),
            wrapped_value,
            pass_arg_as_tensor=True,
            fake_tensor=fake_tensor_value,  # type: ignore[arg-type]
            is_tensor=False,
            example_strong_ref=wrapped_value,
        )
    return unspec_var
def _dataclasses_fields_lambda(obj: VariableTracker) -> TupleVariable:
    """Model ``dataclasses.fields(obj)`` for a user-defined dataclass instance.

    Each ``dataclasses.Field`` becomes a ``UserDefinedObjectVariable``. When
    ``obj`` has a source, each field's source is
    ``obj.__dataclass_fields__[field.name]``. Graph breaks (via
    ``unimplemented``) when ``obj`` is not a ``UserDefinedObjectVariable``.
    """
    is_user_object = isinstance(obj, UserDefinedObjectVariable)
    if not is_user_object:
        unimplemented(
            gb_type="dataclass fields failure",
            context=f"obj: {obj}; variable type: {type(obj)}",
            explanation=f"Dataclass fields handling fails for {obj}. Expected it to be a user-defined object.",
            hints=[],
        )
    value = obj.value if is_user_object else None
    assert value is not None

    def _source_for(field_name):
        # One fresh AttrSource per field, mirroring a per-field lookup of
        # obj.__dataclass_fields__[field_name].
        if not obj.source:
            return None
        fields_attr = AttrSource(obj.source, "__dataclass_fields__")
        return DictGetItemSource(fields_attr, field_name)

    items = [
        UserDefinedObjectVariable(fld, source=_source_for(fld.name))
        for fld in dataclasses.fields(value)  # type: ignore[arg-type]
    ]
    # pyrefly: ignore [bad-argument-type]
    return TupleVariable(items)
- def _clone_input(value: Any, fake_mode: FakeTensorMode | None) -> Any:
- if isinstance(value, torch.Tensor):
- # tensor subclasses will not be converted to FakeTensors and need to be cloned
- if not (
- isinstance(value, FakeTensor)
- or (
- # Is functional tensor fakeified by this instance of Dynamo
- torch._is_functional_tensor(value)
- and maybe_get_fake_mode(value) is fake_mode
- )
- or value.is_nested
- ):
- # NB: ensure strides are preserved
- value = clone_input(value)
- return value
def wrap_fx_proxy(
    tx: "InstructionTranslatorBase",
    proxy: Any,
    example_value: Any | None = None,
    subclass_type: type | None = None,
    **options: Any,
) -> VariableTracker:
    """Wrap ``proxy`` as a tensor VariableTracker via ``wrap_fx_proxy_cls``.

    Plain tensors (``subclass_type is None``) become ``TensorVariable``.
    Tensor subclasses become ``TensorWithTFOverrideVariable``, and that
    variable's ``install_global(tx)`` is called before returning it.
    """
    is_subclass = subclass_type is not None
    target_cls = TensorWithTFOverrideVariable if is_subclass else TensorVariable
    # pyrefly: ignore[bad-argument-type]
    result = wrap_fx_proxy_cls(
        target_cls=target_cls,
        tx=tx,
        proxy=proxy,
        example_value=example_value,
        subclass_type=subclass_type,
        **options,
    )
    if is_subclass:
        result.install_global(tx)  # type: ignore[attr-defined]
    return result
def cache_real_value_when_export(
    tx: "InstructionTranslatorBase", proxy: Any, example_value: Any
) -> None:
    """Under export, store a clone of ``example_value`` as the node's real value.

    The clone (from ``_clone_input``) is written to
    ``proxy.tracer.real_value_cache[proxy.node]``. Outside export this is a
    no-op.
    """
    if not tx.export:
        return
    # The legacy behavior for real value cache with subclasses was
    # to perform a clone WITHOUT preserving the subclass. It's
    # not entirely clear this is what you actually want though.
    with torch._C.DisableTorchFunctionSubclass():
        cloned = _clone_input(example_value, tx.fake_mode)
        proxy.tracer.real_value_cache[proxy.node] = cloned
- # Note: Unfortunate split due to some gross classes existing that subclass TensorVariable
- # Should be compositional instead
- #
- # This is a horribly complicated function that does too many things, to
# explain what it does, let's first talk about the classic usage of wrap_fx_proxy
- # for a TensorVariable. There are two primary modes of use:
- #
- # 1. Wrapping a pre-existing Tensor. In this case, example_value is set
- # to the pre-existing Tensor. (Note that this example_value will NOT
- # be the final example_value we put into node.meta['example_value'],
- # instead it is converted into a fake tensor using
- # wrap_to_fake_tensor_and_record and registered as a graph input.)
- #
- # 2. "Wrapping" the result of some Tensor operation Dynamo traced over. In
- # this case, example_value is None (and we are going to figure it out
- # ourselves using FakeTensors, via get_fake_value, which will run
- # the operation represented by the (singular!) FX node referenced by
- # the passed in proxy.)
- #
- # The expectation is you end up with a Tensor output, and everything is
- # straightforwardly traced into the graph.
- #
- # In all cases, the returned `TensorVariable` subclass will have an `example_value`
- # and that `example_value` must be a `FakeTensor` produced by the currently running
- # instance of Dynamo.
- #
- # Upon closer inspection, you may notice that there are a slurry of non-Tensor
- # output cases in handle_traced_output. What gives? Well, we sometimes trace operations into the
- # graph that don't involve tensors.
- #
- # * Some operators return tuples; we need to recursively handle their
- # contents
- #
- # * Some operators have side effects that will affect subsequent AOTAutograd
- # tracing but don't otherwise return anything.
- #
- # * Some operators return symbolic ints/floats/bools which can go in the
- # graph and be traced (but only if they're actually symbolic! If they're
- # static you don't want to put them in the graph, which means you
- # shouldn't call this function.)
- #
- # The common theme is that you only use this function WHEN YOU ARE TRACING
- # SOMETHING INTO THE GRAPH. This is sort of obvious, because you can't call
- # this function without a proxy.
def wrap_fx_proxy_cls(
    target_cls: type[VTTypeAlias],
    tx: "InstructionTranslatorBase",
    proxy: Any,
    example_value: Any | None = None,
    subclass_type: type | None = None,
    **options: Any,
) -> VTTypeAlias:
    """Wrap ``proxy`` into a ``target_cls`` VariableTracker.

    Dispatches on ``example_value``:

    * ``None``: the op behind ``proxy`` is run on fake values to discover its
      output (``_wrap_fx_proxy``).
    * a ``torch.Tensor``: that pre-existing tensor is wrapped
      (``_wrap_fx_preexisting_tensor``).
    * anything else: passed straight to ``handle_traced_output``.

    ``TensorVariable`` / ``SymNodeVariable`` results for non-placeholder nodes
    are recorded on the current tracer via ``record_tensor_or_symint_vt``.
    """
    positional = (target_cls, tx, proxy, example_value, subclass_type)
    if example_value is None:
        out: VTTypeAlias = _wrap_fx_proxy(*positional, **options)
    elif isinstance(example_value, torch.Tensor):
        out = _wrap_fx_preexisting_tensor(*positional, **options)
    else:
        # This will skip tracing an op and recursively reinvoke wrap_fx_proxy_cls on supported
        # data structures. In essence this just handles tracing some other value which may
        # contain Fake Tensors or is otherwise proxyable.
        # pyrefly: ignore[bad-assignment]
        out = handle_traced_output(
            example_value, tx, proxy, options, subclass_type, target_cls
        )

    is_recordable = isinstance(
        out,
        (
            torch._dynamo.variables.TensorVariable,
            torch._dynamo.variables.SymNodeVariable,
        ),
    )
    if is_recordable and proxy.node.op != "placeholder":
        tx.output.current_tracer.record_tensor_or_symint_vt(out)

    return out
- # This is 1 above (wrapping a preexisting tensor)
def _wrap_fx_preexisting_tensor(
    target_cls: type[VTTypeAlias],
    tx: "InstructionTranslatorBase",
    proxy: torch.fx.Proxy,
    tensor: torch.Tensor,
    subclass_type: type | None = None,
    **options: Any,
) -> VTTypeAlias:
    """Wrap an already-existing tensor (mode 1 of ``wrap_fx_proxy_cls``).

    If ``tensor`` is not already a fake tensor of ``tx.fake_mode``, its real
    value is cached for export via ``cache_real_value_when_export``. It is
    then converted with ``wrap_to_fake_tensor_and_record`` using
    ``options["source"]``, which is required. The resulting tensor is handed
    to ``construct_tensor_variable``.

    Raises:
        AssertionError: if ``tensor`` is not a tensor, the placeholder /
            ``example_value`` meta invariant is violated, or ``source`` is
            missing when conversion is needed.
        InternalTorchDynamoError: if a non-meta tensor is still not a fake
            of ``tx.fake_mode`` after conversion.
    """
    from ..symbolic_convert import InstructionTranslatorBase

    assert isinstance(tensor, torch.Tensor), (
        f"_wrap_fx_preexisting_tensor expected tensor, got {type(tensor)}"
    )

    assert isinstance(tx, InstructionTranslatorBase)
    if "guards" in options and options["guards"] is not None:
        tx.output.guards.update(options["guards"])

    # Placeholders always carry example_value in node.meta.
    # non-placeholders always have no example_value in node.meta
    if proxy.node.op == "placeholder":
        assert "example_value" in proxy.node.meta, (
            f"placeholder {proxy} doesn't have 'example_value' in node.meta"
        )
    else:
        assert "example_value" not in proxy.node.meta, (
            f"{proxy.node.meta['example_value']}"
        )

    # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
    with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
        # Recursive calls hand us a tensor already fakeified by this Dynamo
        # instance; that tensor is used as-is.
        if maybe_get_fake_mode(tensor) is not tx.fake_mode:
            # Under export this stores a subclass-stripped clone in
            # real_value_cache; no second clone is needed here.
            cache_real_value_when_export(tx, proxy, tensor)
            # NB: If we're ignoring subclass, then the expectation is you will
            # take the returned TensorVariable and wrap it into a more
            # accurate TensorVariable that is able to track subclass-ness;
            # otherwise this is wrong!
            kwargs = {
                "is_tensor": target_cls
                in (TensorVariable, TensorWithTFOverrideVariable),
            }
            assert "source" in options and options["source"] is not None
            kwargs["source"] = options["source"]
            # pyrefly: ignore[missing-argument, bad-argument-type]
            tensor = wrap_to_fake_tensor_and_record(tensor, tx=tx, **kwargs)

        if tensor.device.type != "meta" and (
            maybe_get_fake_mode(tensor) is not tx.fake_mode
        ):
            raise InternalTorchDynamoError(
                "`tensor` needs to be a `FakeTensor` "
                f"wrapped by this instance of Dynamo. Found: {tensor}"
            )

    return construct_tensor_variable(
        target_cls, tx, proxy, tensor, subclass_type, options
    )
# Case 2 from the comment above: wrapping the output of an op that was just traced.
def _wrap_fx_proxy(
    target_cls: type[VTTypeAlias],
    tx: "InstructionTranslatorBase",
    proxy: torch.fx.Proxy,
    example_value: Any | None = None,
    subclass_type: type | None = None,
    **options: Any,
) -> VTTypeAlias:
    """Compute a fake example value for ``proxy`` and wrap it as a VariableTracker.

    The ``example_value`` argument is not used: it is always replaced by the
    value ``get_fake_value`` computes for ``proxy.node``. Any non-None
    ``options["guards"]`` are merged into ``tx.output.guards`` first.
    """
    from ..symbolic_convert import InstructionTranslatorBase

    assert isinstance(tx, InstructionTranslatorBase)

    extra_guards = options.get("guards")
    if extra_guards is not None:
        tx.output.guards.update(extra_guards)

    # The node must not have been given an example value yet; it is computed below.
    node_meta = proxy.node.meta
    assert "example_value" not in node_meta, f"{node_meta['example_value']}"

    # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
    with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
        # allow_non_graph_fake is only enabled here because handle_traced_output
        # deals with the non-fake results itself.
        fake_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)

    # pyrefly: ignore[bad-return]
    return handle_traced_output(
        fake_value,
        tx,
        proxy,
        options,
        subclass_type,
        target_cls,
    )
# This handles wrapping of the output of an op traced into the graphs
def handle_traced_output(
    example_value: Any,
    tx: "InstructionTranslatorBase",
    proxy: torch.fx.Proxy,
    options: dict[str, Any],
    subclass_type: type | None,
    target_cls: type[VTTypeAlias],
) -> VariableTracker:
    """Wrap ``example_value``, the fake result of ``proxy.node``, in a VariableTracker.

    Dispatch is on the value's type and on ``proxy.node.target`` / ``op``.
    Branch order matters because several branches can match the same value.

    * Tensors go to ``construct_tensor_variable`` and are registered with
      side-effect tracking as newly constructed objects. Sparse tensors
      graph-break unless exporting with ``config.capture_sparse_compute``.
    * An all-int ``torch.Size`` becomes a ``SizeVariable`` of constants.
    * Tuples and lists get one ``operator.getitem`` proxy per non-None element,
      each wrapped recursively with the same ``target_cls``.
    * SymInt, SymFloat and SymBool become ``SymNodeVariable``.
    * Streams, events, SDPA params and opaque objects get dedicated trackers.
    * Python scalars returned by specific allow-listed targets become
      ``ConstantVariable``.

    Any other value graph-breaks via ``unimplemented``.

    Note: ``options`` may be mutated in place; ``construct_tensor_variable``
    calls ``options.update``.
    """
    import torch._functorch.vmap
    import torch._subclasses.fake_tensor
    import torch._utils

    if isinstance(example_value, torch.Tensor):
        # Check if the result is a sparse tensor -
        # We generally don't support sparse tensor so better to graph break here
        if is_sparse_any(example_value) and (
            not tx.export or not config.capture_sparse_compute
        ):
            unimplemented(
                gb_type="Attempted to wrap sparse Tensor with VariableTracker",
                context=str(example_value),
                explanation="torch.compile does not support sparse Tensors with VariableTracker",
                hints=[*graph_break_hints.SPARSE_TENSOR],
            )
        var = construct_tensor_variable(
            target_cls, tx, proxy, example_value, subclass_type, options
        )
        # NOTE: [Side effect tracking for newly constructed tensor]
        # For newly constructed objects that have mutable attributes, we usually
        # construct their VariableTracker via `track_object_new`, but since
        # tensor variable construction is a bit different, we handle them
        # specially here. This ensures that codegen will actually generate the
        # attribute mutations on this tensor.
        #
        # NOTE we pass a dummy object as the `item` argument to avoid
        # constructing a dummy _tensor_ object. The object isn't used for
        # newly constructed VTs anyways.
        assert isinstance(var, VariableTracker)
        tx.output.side_effects._track_obj(
            proxy, var, mutation_type_cls=AttributeMutationNew
        )
        return var
    elif (
        hasattr(proxy.node.target, "__name__")
        and proxy.node.target.__name__ == "set_state"
        and isinstance(proxy.node.target.__self__, torch._C.Generator)  # type: ignore[attr-defined]
        or proxy.node.target is torch.random.set_rng_state
    ):
        assert type(proxy.node.target) is not str
        # pyrefly: ignore[bad-argument-type]
        return TorchInGraphFunctionVariable(proxy.node.target)
    elif (
        proxy.node.target is torch._C._DisableFuncTorch
        or proxy.node.target is torch.cuda._is_in_bad_fork
    ):
        return UserDefinedObjectVariable(example_value)
    elif istype(example_value, torch.Size) and all(
        isinstance(x, int) for x in example_value
    ):
        sizes = [ConstantVariable.create(x) for x in example_value]
        return SizeVariable(sizes, **options)
    elif isinstance(example_value, (tuple, list)):
        set_example_value(proxy.node, example_value)
        unpacked = []
        for i, val in enumerate(example_value):
            if val is None:
                # nn.MultiheadAttention() can return None, see issue #175
                unpacked.append(
                    ConstantVariable.create(None, **options),
                )
            else:
                proxy_i = proxy.tracer.create_proxy(
                    kind="call_function",
                    target=operator.getitem,
                    args=(proxy, i),
                    kwargs={},
                )

                if "source" in options:
                    # This path should only trigger for list stealing, so it's
                    # safe to use `GetItemSource`.
                    assert isinstance(example_value, list)
                    source = options["source"]
                    options_i = options.copy()
                    options_i["source"] = GetItemSource(
                        base=source,
                        index=i,
                        index_is_slice=False,
                    )
                else:
                    # use the same options object as parent
                    options_i = options

                # WARNING: this assumes the same target_cls as this tuple/list call
                unpacked.append(
                    wrap_fx_proxy_cls(
                        # pyrefly: ignore[bad-argument-type]
                        target_cls=target_cls,
                        tx=tx,
                        proxy=proxy_i,
                        example_value=val,
                        **options_i,
                    )
                )
        if isinstance(example_value, torch.Size):
            # NB: Keep the old proxy around. See SizeVariable for an
            # explanation why
            return SizeVariable(unpacked, proxy, **options)
        elif istype(example_value, tuple):
            return TupleVariable(unpacked, **options)
        elif istype(example_value, (list, immutable_list)):
            return ListVariable(unpacked, **options)
        else:
            assert (
                example_value.__class__.__module__ == "torch.return_types"
                or hasattr(example_value, "_fields")
            ), (
                f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}"
            )
            return NamedTupleVariable(unpacked, example_value.__class__, **options)  # type: ignore[arg-type]
    elif example_value is None or proxy.node.target is torch.manual_seed:
        return ConstantVariable.create(None, **options)
    elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
        tx.output.current_tracer.track_produced_symints(example_value, proxy)
        set_example_value(proxy.node, example_value)
        return SymNodeVariable.create(tx, proxy, example_value, **options)
    elif (
        isinstance(example_value, torch.Stream)
        and proxy.node.target is get_external_object_by_index
    ) or proxy.node.target in [
        device_interface.current_stream
        for _, device_interface in get_registered_device_interfaces()
    ]:
        set_example_value(proxy.node, example_value)
        index = None
        if proxy.node.target is get_external_object_by_index:
            index = proxy.node.args[0]
        return StreamVariable(proxy, example_value, index, **options)  # type: ignore[arg-type]
    elif (
        isinstance(example_value, torch.Event)
        and proxy.node.target is get_external_object_by_index
    ):
        # Registered `current_stream` targets are already fully handled by the
        # stream branch above, so only the external-object lookup applies here.
        index = None
        if proxy.node.target is get_external_object_by_index:
            index = proxy.node.args[0]
        set_example_value(proxy.node, example_value)
        return EventVariable(proxy, example_value, index, **options)  # type: ignore[arg-type]
    elif (
        inspect.isclass(proxy.node.target)
        and issubclass(proxy.node.target, torch.Event)
    ) or proxy.node.target in [
        device_interface.Event
        for _, device_interface in get_registered_device_interfaces()
    ]:
        set_example_value(proxy.node, example_value)
        return EventVariable(proxy, example_value, None, **options)
    elif proxy.node.target == "query" and proxy.node.op == "call_method":
        set_example_value(proxy.node, example_value)
        return ConstantVariable(example_value, **options)
    elif (
        example_value is not None
        and isinstance(example_value, torch.Event)
        and proxy.node.target == "record_event"
        and proxy.node.op == "call_method"
    ):
        set_example_value(proxy.node, example_value)
        return EventVariable(proxy, example_value, None, **options)
    elif isinstance(example_value, int) and (
        proxy.node.target
        in [
            torch.sym_int,
            getattr,
            operator.getitem,
            torch._utils._element_size,
            torch.seed,
            operator.mod,
            torch._functorch.vmap._validate_and_get_batch_size,
            torch._functorch.predispatch._vmap_increment_nesting,
            torch._functorch.predispatch._vmap_decrement_nesting,
            # some mac builds are missing torch.distributed.get_rank()
            getattr(torch.distributed, "get_rank", _missing),
            getattr(torch.distributed, "get_world_size", _missing),
            # This always wants to be in the graph, even if the constraint
            # results in a constant int
            torch._constrain_as_size,
        ]
        or (
            # TODO: this is a little sus, because we didn't check what the self is
            proxy.node.op == "call_method" and proxy.node.target == "bit_length"
        )
    ):
        set_example_value(proxy.node, example_value)
        return ConstantVariable.create(example_value, **options)
    elif isinstance(example_value, torch.backends.cuda.SDPAParams):
        from .sdpa import SDPAParamsVariable

        set_example_value(proxy.node, example_value)
        return SDPAParamsVariable(proxy, **options)
    elif isinstance(example_value, bool) and (
        proxy.node.target
        in [
            torch._C._are_functorch_transforms_active,
            torch._C._functorch.is_batchedtensor,
            torch.backends.cuda.is_flash_attention_available,
            torch.backends.cuda.can_use_flash_attention,
            torch.backends.cuda.can_use_efficient_attention,
            torch._C._get_cudnn_sdp_enabled,
            torch._C._get_flash_sdp_enabled,
            torch._C._get_mem_efficient_sdp_enabled,
            torch._C._get_math_sdp_enabled,
            torch._C._get_overrideable_sdp_enabled,
            "is_integer",
        ]
        + list(supported_const_comparison_op_values.keys())
    ):
        set_example_value(proxy.node, example_value)
        return ConstantVariable.create(example_value, **options)
    elif isinstance(example_value, (int, float, bool)) and (
        proxy.node.target is call_torchbind
        or proxy.node.target is flat_apply
        or (proxy.node.op == "call_method" and proxy.node.target == "item")
    ):
        set_example_value(proxy.node, example_value)
        return ConstantVariable.create(example_value, **options)
    elif isinstance(example_value, float) or proxy.node.target in ["hex", "__round__"]:
        set_example_value(proxy.node, example_value)
        return ConstantVariable.create(example_value, **options)
    elif is_opaque_type(type(example_value)):
        # This is for handling opaque objects in custom ops
        if is_opaque_value_type(type(example_value)):
            proxy = example_value  # pyrefly: ignore[bad-assignment]
        fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj(
            tx.output.fake_mode, example_value
        )
        return TorchScriptObjectVariable.create(
            proxy,
            fake_script_obj,
        )
    else:
        unimplemented(
            gb_type="torch.* op returned non-Tensor",
            context=f"example_value type: {typestr(example_value)}; op: {proxy.node.op}; target: {proxy.node.target}",
            explanation="torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output",
            hints=[],
        )
def infer_subclass_type(value: T) -> type[T] | None:
    """Return the tensor-subclass type Dynamo has to simulate for ``value``.

    Returns ``None`` when ``type(value)`` is exactly ``torch.Tensor``,
    ``torch.nn.Parameter``, ``FakeTensor`` or ``FunctionalTensor``, or when
    ``value`` is a traceable wrapper subclass. Otherwise returns ``type(value)``.
    """
    plain_tensor_types = (
        torch.Tensor,
        torch.nn.Parameter,
        torch._subclasses.fake_tensor.FakeTensor,
        torch._subclasses.functional_tensor.FunctionalTensor,
    )
    if type(value) not in plain_tensor_types and not is_traceable_wrapper_subclass(
        value
    ):
        # Ordinarily, we would fakeify a tensor so that it can get dynamic
        # shapes and be computed on without triggering actual operations.
        # However, how can we fakeify a tensor subclass? Ordinary
        # inheritance (nor multiple inheritance) won't work.
        #
        # Instead, our plan is to *manually simulate* the tensor subclass
        # inheriting from a fake tensor with dynamo. This means our
        # data representation for a tensor subclass will be a fake tensor
        # + tensor subclass type + any extra data the subclass may have
        # been storing on the tensor. Because all Python accesses are
        # mediated through TensorWithTFOverrideVariable, we can ensure
        # that we dispatch differently, e.g., according to
        # __torch_function__
        #
        # To simplify things for now, the __dict__ tracking bits haven't
        # been implemented yet, but they can be added into this design at
        # a later point in time.
        return type(value)
    return None
- def get_specialized_props(
- target_cls: Any,
- tx: "InstructionTranslatorBase",
- example_value: Any,
- subclass_type: type | None,
- ) -> dict[str, Any]:
- specialized_props = target_cls.specialize(example_value)
- # TODO: not sure about this fake mode test
- if (
- isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
- and example_value.fake_mode is tx.fake_mode
- ):
- if subclass_type:
- tensor_type = subclass_type
- elif isinstance(example_value, torch.nn.Parameter):
- tensor_type = torch.nn.Parameter
- elif isinstance(example_value, torch.nn.Buffer):
- tensor_type = torch.nn.Buffer
- else:
- tensor_type = torch.Tensor
- specialized_props["class_type"] = tensor_type
- return specialized_props
def construct_tensor_variable(
    target_cls: type[VTTypeAlias],
    tx: "InstructionTranslatorBase",
    proxy: torch.fx.Proxy,
    example_value: Any,
    subclass_type: type | None,
    options: dict[str, Any],
) -> VTTypeAlias:
    """
    Build the tensor VariableTracker once a pre-existing or newly created
    tensor value has been fully pre-processed.

    The example value is stored on ``proxy.node``. ``options`` is updated in
    place with the specialized properties before it is forwarded to
    ``target_cls``.
    """
    # NB: In most (all?) cases, this does not actually do a clone.
    # (WARNING: this means that if we mutate metadata on the fake
    # tensor, the stored example value will update too!)
    stored_value = _clone_input(example_value, tx.fake_mode)
    set_example_value(proxy.node, stored_value)

    # We bind the unbacked symints in sizes/strides of tensor lazily.
    # So that subgraphs can access the unbacked symbol's proxy in parent graph
    # when lifting unbacked symbols of input tensors to subgraph inputs.
    # We do it lazily because the tensor may not be used in subgraphs.
    is_placeholder = proxy.node.op == "placeholder"
    if not is_placeholder:
        tx.output.current_tracer.track_produced_symints(stored_value, proxy)

    specialized = get_specialized_props(target_cls, tx, stored_value, subclass_type)
    options.update(specialized)
    return target_cls(proxy, **options)
- def get_automatic_dynamic_shapes_mark_as() -> DimDynamic:
- if config.automatic_dynamic_shapes_mark_as == "dynamic":
- return DimDynamic.DYNAMIC
- elif config.automatic_dynamic_shapes_mark_as == "unbacked":
- return DimDynamic.UNBACKED
- else:
- raise ValueError(
- f"invalid automatic_dynamic_shapes_mark_as = {config.automatic_dynamic_shapes_mark_as}"
- )
- _DYNAMIC_SOURCES: set[str] | None = None
- _DYNAMIC_SOURCES_CONFIG_HASH: int | None = None
- def get_dynamic_sources() -> set[str]:
- global _DYNAMIC_SOURCES, _DYNAMIC_SOURCES_CONFIG_HASH
- current_hash = hash(torch.compiler.config.dynamic_sources)
- # If we have already calculated the sources and the config hasn't changed, return cached result
- if _DYNAMIC_SOURCES is not None and _DYNAMIC_SOURCES_CONFIG_HASH == current_hash:
- return _DYNAMIC_SOURCES
- # Config has changed or first time, (re)calculate the sources
- _DYNAMIC_SOURCES = {
- s
- for s in torch.compiler.config.dynamic_sources.replace(" ", "").split(",")
- if s
- }
- _DYNAMIC_SOURCES_CONFIG_HASH = current_hash
- return _DYNAMIC_SOURCES
def is_dynamic_source(source_name: str) -> bool:
    """Return True if ``source_name`` matches an entry of the dynamic-source allowlist.

    An entry matches when it equals ``source_name`` or when ``re.match``
    (anchored at the start only) finds it at the start of ``source_name``.
    The first matching entry is logged at debug level.
    """
    matching_pattern = next(
        (
            pattern
            for pattern in get_dynamic_sources()
            if pattern == source_name or re.match(pattern, source_name)
        ),
        None,
    )
    if matching_pattern is None:
        return False
    log.debug(
        "%s was marked dynamic due to dynamic source allowlist pattern: %s",
        source_name,
        matching_pattern,
    )
    return True
def record_automatic_dynamic(
    tx: "InstructionTranslatorBase", name: str, e: torch.Tensor
) -> FrameStateSizeEntry:
    """Record the size and stride layout of ``e`` in the frame state under ``name``.

    Stride encoding mimics the stride inference in
    ``_create_symbolic_sizes_strides_storage_offset``. Dimensions are visited
    in ``_nested_int_aware_sort`` order of ``(stride, -dim)``. A stride equal to
    ``stride * size`` of a dimension visited earlier is stored as
    ``InferStride(that_dim)``; otherwise the concrete stride is stored.
    Sparse tensors record an empty stride tuple.
    """
    sizes = e.size()
    if is_sparse_any(e):
        # pyrefly: ignore [implicit-any]
        strides = []
    else:
        concrete_strides = e.stride()
        ndim = e.dim()
        strides = [None] * ndim
        visit_order = sorted(
            ((concrete_strides[d], -d) for d in range(ndim)),
            key=_nested_int_aware_sort,
        )
        # Maps "stride * size" of an already-visited dim to an InferStride for it.
        inferable = {}
        for dim_stride, neg_dim in visit_order:
            d = -neg_dim
            # pyrefly: ignore [unsupported-operation]
            strides[d] = inferable.get(dim_stride, dim_stride)
            # pyrefly: ignore [no-matching-overload]
            inferable.setdefault(dim_stride * sizes[d], InferStride(d))

    return process_automatic_dynamic(
        tx,
        name,
        FrameStateSizeEntry.make_tensor(tuple(sizes), tuple(strides)),  # type: ignore[arg-type]
    )
- _UNBACKED_SOURCES: set[str] | None = None
- _UNBACKED_SOURCES_CONFIG_HASH: int | None = None
- def get_unbacked_sources() -> set[str]:
- global _UNBACKED_SOURCES, _UNBACKED_SOURCES_CONFIG_HASH
- current_hash = hash(torch.compiler.config.unbacked_sources)
- # If we have already calculated the sources and the config hasn't changed, return cached result
- if _UNBACKED_SOURCES is not None and _UNBACKED_SOURCES_CONFIG_HASH == current_hash:
- return _UNBACKED_SOURCES
- # Config has changed or first time, (re)calculate the sources
- _UNBACKED_SOURCES = {
- s
- for s in torch.compiler.config.unbacked_sources.replace(" ", "").split(",")
- if s
- }
- _UNBACKED_SOURCES_CONFIG_HASH = current_hash
- return _UNBACKED_SOURCES
def is_unbacked_source(source_name: str) -> bool:
    """Return True if ``source_name`` matches an entry of the unbacked-source allowlist.

    An entry matches when it equals ``source_name`` or when ``re.match``
    (anchored at the start only) finds it at the start of ``source_name``.
    The first matching entry is logged at debug level.
    """
    matching_pattern = next(
        (
            pattern
            for pattern in get_unbacked_sources()
            if pattern == source_name or re.match(pattern, source_name)
        ),
        None,
    )
    if matching_pattern is None:
        return False
    log.debug(
        "%s was marked unbacked due to unbacked source allowlist pattern: %s",
        source_name,
        matching_pattern,
    )
    return True
# Performs automatic dynamic dim determination.
# Returns a SymbolicContext
def _automatic_dynamic(
    e: Any,
    tx: "InstructionTranslatorBase",
    source: Source,
    static_shapes: bool,
    outer_only: bool = False,
) -> SymbolicContext:
    """Build the SymbolicContext that decides, per dimension, how ``e`` is fakified.

    Args:
        e: the tensor being wrapped. It may be a view, a nested tensor or a
            traceable wrapper subclass.
        tx: the translator. Its output provides the tracing context, export
            constraints and shape env.
        source: where ``e`` came from. ``source.name`` keys the frame state,
            the dynamic/unbacked source allowlists and the debug names.
        static_shapes: request all-STATIC sizes, unless ``source.name`` is on
            the dynamic-source allowlist.
        outer_only: when True, skip the traceable-wrapper-subclass handling
            and compute only the outer tensor's context. The recursive call
            made for the subclass itself passes True.

    Returns:
        A ``SubclassSymbolicContext`` for traceable wrapper subclasses when
        ``outer_only`` is False, otherwise a ``StatefulSymbolicContext``. If
        ``e`` is a view, the context of ``e._base`` is attached as
        ``view_base_context``.

    Graph-breaks via ``unimplemented`` on nested tensors that are not
    ``torch.nested._internal.nested_tensor.NestedTensor`` (strided NTs).
    """
    # strided NT not supported
    if e.is_nested and not isinstance(
        e, torch.nested._internal.nested_tensor.NestedTensor
    ):
        unimplemented(
            gb_type="Encountered strided NestedTensor in automatic dynamic dim determination",
            context="",
            explanation="torch.compile does not support strided NestedTensor",
            hints=[],
        )

    name = source.name
    # Reuse the symbol cache of a context previously recorded for this tensor, if any.
    prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None)
    shape_env_to_source_to_symbol_cache = (
        prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else {}
    )

    # Get base context if the tensor is a view
    view_base_context: SymbolicContext | None = None
    if e._is_view():
        base_source = AttrSource(source, "_base")
        view_base_context = _automatic_dynamic(e._base, tx, base_source, static_shapes)

    if is_traceable_wrapper_subclass(e) and not outer_only:
        # Get symbolic context for outer tensor
        outer_context = _automatic_dynamic(
            e, tx, source, static_shapes, outer_only=True
        )
        assert isinstance(outer_context, StatefulSymbolicContext)

        # Get symbolic contexts for inner tensors
        inner_contexts = {}  # mapping from attr -> symbolic context
        attrs, _ = type(e).__tensor_flatten__(e)
        for attr in attrs:
            inner_tensor = getattr(e, attr)
            inner_source = AttrSource(source, attr)
            inner_contexts[attr] = _automatic_dynamic(
                inner_tensor, tx, inner_source, static_shapes
            )

        return SubclassSymbolicContext(
            dynamic_sizes=outer_context.dynamic_sizes,
            dynamic_strides=outer_context.dynamic_strides,
            constraint_sizes=outer_context.constraint_sizes,
            constraint_strides=outer_context.constraint_strides,
            view_base_context=view_base_context,
            tensor_source=outer_context.tensor_source,
            shape_env_to_source_to_symbol_cache=outer_context.shape_env_to_source_to_symbol_cache,
            inner_contexts=inner_contexts,
        )

    # Static shapes requested and not overridden by the dynamic-source
    # allowlist: every size is STATIC and every stride is inferred.
    if static_shapes and not is_dynamic_source(name):
        return StatefulSymbolicContext(
            dynamic_sizes=[DimDynamic.STATIC] * e.dim(),
            dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(),
            constraint_sizes=[None] * e.dim(),
            constraint_strides=[None] * e.dim(),
            view_base_context=view_base_context,
            tensor_source=source,
            shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
        )

    # We preserve the dynamism of inputs. For example, when users call
    # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes.
    from torch.fx.experimental.symbolic_shapes import is_nested_int

    if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()):
        return StatefulSymbolicContext(
            dynamic_sizes=[
                DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC
                for s in e.size()
            ],
            dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(),
            constraint_sizes=[None] * e.dim(),
            constraint_strides=[None] * e.dim(),
            view_base_context=view_base_context,
            tensor_source=source,
            shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
        )

    # Prep for automatic dynamic
    frame_state_entry = record_automatic_dynamic(tx, name, e)

    # TODO: index export_constraints ahead of time so we don't have to
    # do a linear scan every time here
    # Export constraints identify their tensor by id().
    t_id = id(e)
    # Maps dim -> (constraint range, Dim name) collected from export constraints.
    # pyrefly: ignore [implicit-any]
    dim2constraint = {}

    def update_dim2constraint(
        dim: int, constraint_range: "StrictMinMaxConstraint", name: str
    ) -> None:
        # Record a constraint for `dim`. If `dim` is already constrained, the
        # value ranges are intersected and the first non-empty name is kept.
        if dim in dim2constraint:
            from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

            old_constraint_range, old_name = dim2constraint[dim]
            new_constraint_range = StrictMinMaxConstraint(
                vr=constraint_range.vr & old_constraint_range.vr,
                warn_only=False,
            )
            # It is possible for (non-None) old_name and name to be different
            # but this will only happen the corresponding Dims can be derived equal.
            new_name = old_name or name
            dim2constraint[dim] = new_constraint_range, new_name
        else:
            dim2constraint[dim] = constraint_range, name

    from torch.export.dynamic_shapes import _RelaxedConstraint

    if tx.output.export_constraints is not None:
        for constraint in tx.output.export_constraints:  # type: ignore[iterable]
            # Relaxed constraints carry no value range and are skipped here.
            if isinstance(constraint, _RelaxedConstraint):
                continue
            if constraint.t_id == t_id:
                update_dim2constraint(
                    constraint.dim, constraint.constraint_range, constraint.name
                )

    dynamic_sizes = []
    dynamic_strides = []
    constraint_sizes = []
    constraint_strides = []
    specialize_on = []
    # Per-dimension decision. Size dynamism is resolved in this order:
    # UNBACKED (marked unbacked, or unbacked-source allowlist) >
    # dynamic (a constraint, marked dynamic / weak dynamic, or a nested int) >
    # STATIC (static_shapes, config.assume_static_by_default, marked static) >
    # DUCK.
    for i in range(e.dim()):
        # NB: mark dynamic has precedence over static
        marked_strict_unbacked = i in getattr(
            e, "_dynamo_strict_unbacked_indices", set()
        )
        marked_unbacked = i in getattr(e, "_dynamo_unbacked_indices", set())
        marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set())
        marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
        marked_static = i in getattr(e, "_dynamo_static_indices", set())
        specialize_on.append(getattr(e, "_specialize_on", {}).get(i, []))

        # Reflect the user directive in the frame_state
        # For dynamic, apply None always
        normalized_source_name = normalize_source_name(source.name)
        base_source = source
        if isinstance(base_source, ChainedSource):
            base_source = base_source.get_base()

        if marked_dynamic or (
            isinstance(base_source, LocalSource)
            and base_source.dynamism is not None
            # pyrefly: ignore[no-matching-overload]
            and dict(base_source.dynamism).get(normalized_source_name, {i: False})[i]
        ):
            # TODO: This can be batched
            # TODO: Doing this here is kind of sus, maybe better to set this
            # up when we initially created the FrameStateSizeEntry to bong
            # into the mutable state
            log.debug("automatic dynamic %s marked dynamic", name)
            mark_size = [auto_unset] * e.dim()
            # pyrefly: ignore [unsupported-operation]
            mark_size[i] = auto_dynamic
            # pyrefly: ignore [bad-argument-type]
            frame_state_entry |= FrameStateSizeEntry.make_size(size=mark_size)

        # NB: both static and dynamic user marks take precedence over automatic
        # dynamic; see the constraint selection below, where the marked_dynamic
        # and marked_static checks come before automatic_dynamic.
        automatic_dynamic_size = (
            config.automatic_dynamic_shapes and frame_state_entry.is_size_dynamic(i)
        )
        # NB: previously, if size was dynamic, we wouldn't make its stride
        # dynamic. But now, because of InferStride concept, we will properly
        # not make stride dynamic even if it's wobbling
        automatic_dynamic_stride = (
            config.automatic_dynamic_shapes and frame_state_entry.is_stride_dynamic(i)
        )

        # NOTE(review): these allowlist checks depend only on `name`, so they
        # are loop-invariant; they (and their debug logging) run once per dim.
        if is_dynamic_source(name):
            log.debug("%s marked dynamic via source whitelist", name)
            automatic_dynamic_size = True

        if is_unbacked_source(name):
            log.debug("%s marked unbacked via source whitelist", name)
            automatic_dynamic_size = True

        automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride

        # We will process constraints first, as they will imply that we
        # have a dynamic dimension
        # Precedence: export constraints > eager constraints
        constraint = dim2constraint.get(i)
        if constraint is None:
            constraint_size = None
            constraint_stride = None
            if marked_dynamic and not config.allow_ignore_mark_dynamic:
                # constraint_stride is deliberately kept None because no easy way to provide value ranges for mark dynamic
                constraint_stride = None
                if hasattr(e, "_dynamo_dynamic_range"):
                    # NOTE(review): .pop() raises IndexError if
                    # _dynamo_dynamic_range has no entry for dim i, and picks an
                    # arbitrary entry if several exist (it iterates a container
                    # whose order may not be defined) — confirm that exactly one
                    # range is recorded per marked dim.
                    dim_range = [
                        dr for dr in e._dynamo_dynamic_range if dr.dim == i
                    ].pop()
                    if dim_range.min is None and dim_range.max is None:
                        constraint_size = RelaxedUnspecConstraint(warn_only=False)
                    else:
                        from torch.fx.experimental.symbolic_shapes import (
                            StrictMinMaxConstraint,
                        )

                        constraint_size = StrictMinMaxConstraint(
                            vr=ValueRanges(lower=dim_range.min, upper=dim_range.max),
                            warn_only=False,
                        )
                else:
                    constraint_size = RelaxedUnspecConstraint(warn_only=False)
            elif marked_strict_unbacked:
                constraint_size = RelaxedUnspecConstraint(warn_only=False)
            elif not marked_static and automatic_dynamic:
                set_feature_use("dynamo.automatic_dynamic_shapes", True)
                # warn_only=True: automatic dynamism is a heuristic, not a user request.
                if automatic_dynamic_size:
                    constraint_size = RelaxedUnspecConstraint(warn_only=True)
                if automatic_dynamic_stride:
                    constraint_stride = RelaxedUnspecConstraint(warn_only=True)
            else:
                if not marked_static and not config.automatic_dynamic_shapes:
                    set_feature_use("dynamo.automatic_dynamic_shapes", False)
                constraint_size = None
                constraint_stride = None
        else:
            # An export constraint exists for this dim; register its Dim name
            # as the debug name of this size.
            constraint_size, name_ = constraint
            constraint_stride = None
            dim_name = f"{name}.size()[{i}]"
            tx.output.shape_env.source_name_to_debug_name[dim_name] = name_
        constraint_sizes.append(constraint_size)
        constraint_strides.append(constraint_stride)

        if marked_unbacked or is_unbacked_source(name):
            dynamic_size = DimDynamic.UNBACKED
        elif (
            constraint_size is not None
            or marked_dynamic
            or marked_weak_dynamic
            or is_nested_int(e.size()[i])
        ):
            # NB: We could assert static_shapes is False here, but it
            # seems better to allow the user to override symbolic_context in this
            # case
            if automatic_dynamic:
                dynamic_size = get_automatic_dynamic_shapes_mark_as()
            else:
                dynamic_size = DimDynamic.DYNAMIC
        elif static_shapes or config.assume_static_by_default or marked_static:
            dynamic_size = DimDynamic.STATIC
        else:
            # TODO: When does this show up?
            dynamic_size = DimDynamic.DUCK

        # Strides are only made symbolic when a stride constraint was chosen above.
        if constraint_stride is not None:
            dynamic_stride = DimDynamic.DYNAMIC
        else:
            dynamic_stride = DimDynamic.INFER_STRIDE

        dynamic_sizes.append(dynamic_size)
        dynamic_strides.append(dynamic_stride)

    return StatefulSymbolicContext(
        dynamic_sizes=dynamic_sizes,
        dynamic_strides=dynamic_strides,
        constraint_sizes=constraint_sizes,
        # pyrefly: ignore [bad-argument-type]
        constraint_strides=constraint_strides,
        specialize_on=specialize_on,
        view_base_context=view_base_context,
        tensor_source=source,
        shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
        shape_ids=getattr(e, "_dynamo_shape_ids", None),
    )
# See note [Tensor Fakification and Symbol Caching]
def wrap_to_fake_tensor_and_record(
    e: Any,
    tx: "InstructionTranslatorBase",
    *,
    source: Source | None,
    is_tensor: bool,
    parent_context: Any | None = None,
) -> Any:
    """Fakify ``e`` (when it is a tensor) and record it on ``tx``, timing the work.

    This is a thin timing wrapper around ``_wrap_to_fake_tensor_and_record_impl``.
    The elapsed time is added to
    ``tx.output.bytecode_tracing_timings.wrap_to_fake_tensor_and_record_ns``,
    even when the implementation raises.
    """
    # perf_counter_ns is a monotonic clock, unlike the wall-clock time_ns, so
    # the measured duration is not skewed (or negative) if the system clock is
    # adjusted while tracing.
    # NOTE(review): the implementation calls this function recursively for the
    # inner tensors of traceable wrapper subclasses, so their time is counted
    # both in the inner call and in the enclosing outer call.
    start_ns = time.perf_counter_ns()
    try:
        return _wrap_to_fake_tensor_and_record_impl(
            e, tx, source=source, is_tensor=is_tensor, parent_context=parent_context
        )
    finally:
        tx.output.bytecode_tracing_timings.wrap_to_fake_tensor_and_record_ns += (
            time.perf_counter_ns() - start_ns
        )
def _wrap_to_fake_tensor_and_record_impl(
    e: Any,
    tx: "InstructionTranslatorBase",
    *,
    source: Source | None,
    is_tensor: bool,
    parent_context: Any | None = None,
) -> Any:
    """Convert ``e`` into a FakeTensor of ``tx.fake_mode`` and record its bookkeeping.

    Values that are not tensors (nor traceable wrapper subclasses) are
    returned unchanged. For tensors this function:

    * chooses a symbolic context: via ``_automatic_dynamic`` for a top-level
      value, or from ``parent_context.inner_contexts`` when called
      recursively for a subclass's inner tensor;
    * fakifies ``e`` under ``enable_python_dispatcher``;
    * records a TrackedFake for a memoized ``.item()`` value and recurses into
      the inner tensors of traceable wrapper subclasses;
    * stores the symbolic context and size/stride metadata on ``tx.output``;
    * registers the fake in ``tx.output.tracked_fakes``, unless ``is_tensor``
      is False, the source is constant, or the source is a specialized
      nn-module with static shapes.
    """
    looks_like_tensor = (
        type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor)
        or isinstance(e, torch.Tensor)
        or is_traceable_wrapper_subclass(e)
    )
    if not looks_like_tensor:
        return e

    assert source is not None
    static_shapes, _ = tensor_always_has_static_shape(
        e,
        is_tensor,
        tensor_source=source,
    )

    if parent_context:
        # Parent contexts are passed in when we are recursively creating
        # fake tensors for subclasses. A better design would be not to create a
        # parent/child relationship, but to recursively call _automatic_dynamic
        # as we recursively call wrap_to_fake_tensor_and_record. This runs
        # into bugs around how meta_utils knows and works to create fake tensors
        # with tensor subclasses. Ideally, dynamo would drive both the recursive
        # wrap_to_fake_tensor_and_record and _automatic_dynamic policy creation.
        assert isinstance(source, AttrSource)
        symbolic_context = parent_context.inner_contexts[source.member]
    else:
        symbolic_context = _automatic_dynamic(e, tx, source, static_shapes)

    log.debug(
        "wrap_to_fake %s %s %s %s",
        source.name,
        tuple(e.shape),
        symbolic_context,
        type(e),
    )

    # Note [enable_python_dispatcher in dynamo]
    # Dynamo disables itself when it runs fake tensor prop, which means that tensor subclasses
    # have no way to know (purely based off of global state) if they are currently being run under compile or not.
    # we use enable_python_dispatcher mainly to tweak the DispatchKeyState so that subclass authors
    # can check it to know if they are running in an eager context or not
    with enable_python_dispatcher():
        assert tx.fake_mode is not None
        fake_e = wrap_fake_exception(
            lambda: tx.fake_mode.from_tensor(
                e,  # type: ignore[arg-type]
                source=source,
                symbolic_context=symbolic_context,
            )
        )

    # A memoized .item() value is tracked under a dedicated source.
    if (
        source is not None
        and isinstance(fake_e, FakeTensor)
        and (sym_val := fake_e.item_memo) is not None
    ):
        tx.output.tracked_fakes.append(
            TrackedFake(sym_val, CallMethodItemSource(source), symbolic_context)
        )

    if is_traceable_wrapper_subclass(fake_e):
        inner_attr_names, _ = fake_e.__tensor_flatten__()
        for attr_name in inner_attr_names:
            fake_inner = getattr(fake_e, attr_name)
            wrap_to_fake_tensor_and_record(
                getattr(e, attr_name),
                tx,
                source=AttrSource(source, attr_name),
                is_tensor=isinstance(fake_inner, torch.Tensor),
                parent_context=symbolic_context,
            )

    tx.output.tracing_context.tensor_to_context[e] = symbolic_context

    if is_sparse_any(fake_e):
        # TODO: for TensorGuards, this eventually may need more
        # fields for the size/stride of any other constituents
        sparse_values = fake_e._values() if fake_e.is_sparse else fake_e.values()
        size_stride_record = {
            "size": fake_e.size(),
            # TODO: revise this, but for now this stride instead of ()
            # avoids SegFault with PYTORCH_TEST_WITH_DYNAMO=1
            "stride": (1,) * fake_e.ndim,
            "values_size": sparse_values.size(),
            "values_stride": sparse_values.stride(),
        }
    else:
        size_stride_record = {
            "size": fake_e.size(),
            "stride": fake_e.stride(),
        }
    tx.output.input_source_to_sizes_strides[source] = size_stride_record

    should_track = (
        is_tensor
        and not (static_shapes and source.is_specialized_nn_module())
        and not is_constant_source(source)
    )
    if should_track:
        tx.output.tracked_fakes.append(TrackedFake(fake_e, source, symbolic_context))
        tx.output.tracked_fakes_id_to_source[id(e)].append(source)

    return fake_e
class SourcelessBuilder:
    """
    Like the source-tracking builder, but stateless and does not require a source.

    Useful for simple type->VT objects, or for objects that are created and
    discarded during inlining. For example, consider a locally made list of
    tensors that we then iterate over: such a list should not show up as an
    artifact from inputs, nor in reconstruction, nor in the graph. However,
    there may be reasons to represent it as a ListVariable internally.

    NOTE - Objects produced here are born UNGUARDED due to the nature of sources!
    NOTE - This class is very new! It will have some rough edges, but it was
    created to stem the bleeding of giant if/else type->VariableTracker trees
    that were cropping up all over dynamo.
    """

    # Exact-type dispatch table consulted first by ``create``. It is filled in
    # right after the class body (the handlers need ``create`` to exist);
    # declaring it here lets type checkers see the attribute without a
    # ``type: ignore`` at the use site.
    _type_handlers: dict[
        type, Callable[["InstructionTranslator", Any], VariableTracker]
    ]

    def __init__(self) -> None:
        # Not instantiable: every entry point is a staticmethod.
        raise AssertionError("Use SourcelessBuilder.create()")

    @overload
    @staticmethod
    def create(
        tx: "InstructionTranslatorBase",
        value: type[set[Any]]
        | type[dict[Any, Any]]
        | type[tuple[Any, ...]]
        | type[list[Any]],
    ) -> BuiltinVariable: ...

    @overload
    @staticmethod
    def create(tx: "InstructionTranslatorBase", value: list[Any]) -> ListVariable: ...

    @overload
    @staticmethod
    def create(
        tx: "InstructionTranslatorBase", value: tuple[Any, ...]
    ) -> TupleVariable: ...

    @overload
    @staticmethod
    def create(
        tx: "InstructionTranslatorBase", value: bool | int | float | str
    ) -> ConstantVariable: ...

    @overload
    @staticmethod
    def create(tx: "InstructionTranslatorBase", value: Any) -> VariableTracker: ...

    @staticmethod
    def create(tx: "InstructionTranslatorBase", value: Any) -> VariableTracker:
        """
        Wrap a Python ``value`` that has no source in a ``VariableTracker``.

        Values whose exact type has an entry in ``_type_handlers`` are handled
        first. Everything else goes through an ordered chain of isinstance and
        predicate checks. The order matters: earlier checks shadow later ones.
        When no rule matches, ``unimplemented`` is called with the
        "Unexpected type in sourceless builder" graph-break type.
        """
        value_type = type(value)
        fast_handler = SourcelessBuilder._type_handlers.get(value_type)
        if fast_handler is not None:
            return fast_handler(tx, value)

        if isinstance(value, VariableTracker):
            # This is always valid to call, and useful for recursive calls.
            return value
        elif is_opaque_value_type(value_type):
            # This is for handling opaque objects in custom ops
            fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj(
                tx.output.fake_mode, value
            )
            return TorchScriptObjectVariable.create(
                value,
                fake_script_obj,
            )
        elif isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS):  # type: ignore[attr-defined]
            return UserDefinedObjectVariable(value)
        elif ConstantVariable.is_literal(value):
            return ConstantVariable.create(value)
        elif (
            callable(value)
            and (callable_rule := trace_rules.lookup_callable(value)) is not None
        ):
            # Look the rule up once and reuse it for the construction below.
            if trace_rules.is_callable_allowed(value):
                tx.output.has_user_defined_allowed_in_graph = True
            # pyrefly: ignore[not-callable, bad-argument-count]
            return callable_rule(value)
        elif callable(value) and UserDefinedClassVariable.is_supported_new_method(
            value
        ):
            # NamedTuple._make uses an alias of tuple.__new__
            # pyrefly: ignore[not-callable, bad-argument-count, missing-attribute]
            obj = trace_rules.lookup_callable(value.__self__)(value.__self__)
            return GetAttrVariable(obj, "__new__")
        elif is_function_or_wrapper(value):
            # pyrefly: ignore[not-callable, bad-argument-count]
            return trace_rules.lookup(value)(value)
        elif isinstance(
            value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType)
        ):
            return EnumVariable(value)
        elif isinstance(value, (type, abc.ABCMeta)):
            if isinstance(value, type) and issubclass(value, enum.Enum):
                return UserDefinedEnumClassVariable(value)
            return UserDefinedClassVariable(value)
        elif isinstance(value, types.MethodWrapperType):
            return MethodWrapperVariable(value)
        elif (
            isinstance(value, types.MethodType)
            # We only want to support sourceless class objects here
            # An instance variable is not allowed and it should have source
            and isinstance(value.__self__, (type, abc.ABCMeta))
        ):
            # value is a classmethod
            assert getattr(value.__self__, value.__func__.__name__) == value
            cls_obj_vt = SourcelessBuilder.create(tx, value.__self__)
            try:
                # pyrefly: ignore[bad-argument-type]
                return cls_obj_vt.var_getattr(tx, value.__func__.__name__)
            except NotImplementedError:
                pass  # fall through to the unimplemented() call below
        elif isinstance(value, torch.fx.graph_module.GraphModule):
            return SourcelessGraphModuleVariable(value)
        elif isinstance(value, torch.utils._pytree.TreeSpec):
            return UserDefinedObjectVariable(value)
        elif PlacementVariable.is_placement(value):
            return PlacementVariable(value)
        elif DeviceMeshVariable.is_device_mesh(value):
            return DeviceMeshVariable(value)
        elif value is functools.wraps:
            return FunctoolsWrapsVariable(value)
        elif isinstance(value, re.Pattern):
            return ConstantLikeVariable(value)
        elif isinstance(value, torch._dynamo.variables.lazy.LazySymNodeFormatString):
            return ConstantVariable.create(str(value))
        elif isinstance(value, type(torch._higher_order_ops.flex_attention_backward)):
            return torch._dynamo.variables.higher_order_ops.FlexAttentionBackwardHighOrderVariable(
                value
            )
        elif isinstance(value, (types.GenericAlias, types.UnionType)):
            return TypingVariable(value)
        elif is_namedtuple(value):
            output = [
                SourcelessBuilder.create(tx, getattr(value, name))
                for name in namedtuple_fields(type(value))
            ]
            return NamedTupleVariable(output, tuple_cls=type(value))
        elif (
            isinstance(value, torch.SymInt)
            and value.node.expr in tx.output.bound_symbols
        ):
            proxy = tx.output.bound_symbols[value.node.expr]
            return SymNodeVariable.create(tx, proxy)
        elif istype(value, object):
            return ObjectVariable(value)
        unimplemented(
            gb_type="Unexpected type in sourceless builder",
            context=f"{value_type.__module__}.{value_type.__qualname__}",
            explanation=f"SourcelessBuilder.create does not know how to wrap {value_type}",
            hints=[*graph_break_hints.DYNAMO_BUG],
        )

    @staticmethod
    def wrap_constant_literal(value: object) -> VariableTracker:
        """Wrap a value that ``ConstantVariable.is_literal`` accepts (asserted)."""
        assert ConstantVariable.is_literal(value)
        return ConstantVariable.create(value=value)

    @staticmethod
    def make_type_handlers() -> dict[
        type, Callable[["InstructionTranslator", Any], VariableTracker]
    ]:
        """
        Build the exact-type dispatch table used first by ``create``.

        Lookups are keyed on ``type(value)`` exactly, so subclasses of these
        types do not match here and fall through to the isinstance chain in
        ``create``. Container handlers wrap their elements recursively via
        ``create``.
        """
        create = SourcelessBuilder.create
        handlers: dict[
            type, Callable[[InstructionTranslator, Any], VariableTracker]
        ] = {}

        def wrap_items(
            tx: "InstructionTranslator", value: Any
        ) -> list[VariableTracker]:
            # Wrap every element of an iterable container, in iteration order.
            return [create(tx, x) for x in value]

        def wrap_mapping(
            tx: "InstructionTranslator", value: Any
        ) -> dict[VariableTracker, VariableTracker]:
            # Wrap every key and value of a mapping.
            return {create(tx, k): create(tx, v) for k, v in value.items()}

        def wrap_constant(tx: "InstructionTranslator", value: Any) -> VariableTracker:
            return ConstantVariable(value)

        def wrap_new_user_defined_object(
            tx: "InstructionTranslator", value: Any
        ) -> VariableTracker:
            # Opaque object wrapper tagged with a fresh ValueMutationNew().
            return UserDefinedObjectVariable(value, mutation_type=ValueMutationNew())

        for t in common_constant_types:
            handlers[t] = wrap_constant

        handlers[set] = lambda tx, value: SetVariable(
            wrap_items(tx, value), mutation_type=ValueMutationNew()
        )
        handlers[OrderedSet] = lambda tx, value: OrderedSetVariable(
            wrap_items(tx, value), mutation_type=ValueMutationNew()
        )
        handlers[dict] = lambda tx, value: ConstDictVariable(
            wrap_mapping(tx, value),
            type(value),
            mutation_type=ValueMutationNew(),
        )
        handlers[list] = lambda tx, value: ListVariable(
            wrap_items(tx, value), mutation_type=ValueMutationNew()
        )
        handlers[tuple] = lambda tx, value: TupleVariable(wrap_items(tx, value))
        handlers[torch.Size] = lambda tx, value: SizeVariable(wrap_items(tx, value))
        handlers[collections.OrderedDict] = handlers[dict]
        handlers[immutable_dict] = handlers[dict]
        handlers[immutable_list] = handlers[list]

        # Sourceless MappingProxyType object can be encountered while tracing
        # type.__dict__["__dict__"].__get__
        handlers[types.MappingProxyType] = lambda tx, value: MappingProxyVariable(
            ConstDictVariable(
                wrap_mapping(tx, value),
                dict,
                mutation_type=ValueMutationNew(),
            ),
        )
        handlers[types.GetSetDescriptorType] = (
            lambda tx, value: GetSetDescriptorVariable(value)
        )
        handlers[random.Random] = lambda tx, value: RandomClassVariable()
        handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value)
        handlers[torch.DispatchKeySet] = lambda tx, value: DispatchKeySetVariable(
            value, mutation_type=ValueMutationNew()
        )
        handlers[torch._functorch.pyfunctorch.FuncTorchInterpreter] = (
            lambda tx, value: FuncTorchInterpreterVariable(
                value, mutation_type=ValueMutationNew()
            )
        )

        # These types all share the same opaque-object wrapper.
        for udo_type in (
            inspect.Parameter,
            torch.distributions.constraints._Real,
            torch.distributions.constraints._Interval,
            torch.distributions.constraints.Constraint,
        ):
            handlers[udo_type] = wrap_new_user_defined_object

        def passthrough(tx: "InstructionTranslator", value: T) -> T:
            # VariableTrackers are already wrapped; return them unchanged.
            return value

        for cls in VariableTrackerMeta.all_subclasses:
            handlers[cls] = passthrough
        return handlers


# Populated after the class body because the handlers close over
# SourcelessBuilder.create.
SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers()
class SourcelessUserDefinedObjectBuilder:
    """
    Sourceless builder that is allowed to produce user-defined-object trackers.

    SourceLessBuilder does not return a UserDefinedObjectVariable. When a
    caller knows such a result is acceptable, it should use this builder
    instead. Every tracker built here is tagged with ``ValueMutationNew()``.
    """

    def __init__(self) -> None:
        # Not instantiable: use the static ``create`` entry point.
        raise AssertionError("Use SourcelessUserDefinedObjectBuilder.create()")

    @staticmethod
    def create(tx: "InstructionTranslator", value: Any) -> VariableTracker:
        """
        Wrap ``value`` in the most specific user-defined tracker.

        - If ``type(value)`` subclasses ``MutableMapping``, return a
          ``MutableMappingVariable``.
        - If ``value`` is a ``torch.nn.Module``, return an
          ``UnspecializedNNModuleVariable``.
        - Otherwise, return a ``UserDefinedObjectVariable``.

        ``tx`` is not read by this method.
        """
        cls_of_value = type(value)
        if issubclass(cls_of_value, MutableMapping):
            return MutableMappingVariable(value, mutation_type=ValueMutationNew())
        if isinstance(value, torch.nn.Module):
            return UnspecializedNNModuleVariable(
                value, mutation_type=ValueMutationNew()
            )
        return UserDefinedObjectVariable(value, mutation_type=ValueMutationNew())
|