| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
277327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319332033213322332333243325332633273328332933303331333233333334333533363337333833393340334133423343334433453346334733483349 |
- """
- Built-in function and type variable tracking for TorchDynamo's symbolic execution.
- This module contains variable tracker classes for Python built-in functions, types,
- and operations during graph compilation. It handles symbolic execution of:
- - Built-in functions (len, getattr, isinstance, etc.)
- - Type constructors (int, float, str, list, dict, etc.)
- - Built-in operators and methods
- - Special Python constructs (super, hasattr, etc.)
- Key classes:
- - BuiltinVariable: Tracks built-in functions and handles their execution
- - TypeVariable: Manages type constructor calls and type checking
- - SuperVariable: Handles super() calls in class hierarchies
- These variable trackers ensure that built-in Python operations are correctly
- handled during symbolic execution, either by executing them directly when safe
- or by creating appropriate graph nodes when needed.
- """
- import contextlib
- import functools
- import inspect
- import itertools
- import logging
- import math
- import operator
- import sys
- import types
- import typing
- import unittest
- from collections import defaultdict, OrderedDict
- from collections.abc import Callable, Iterable, KeysView, Sequence
- from typing import Any, cast, Literal, TYPE_CHECKING, Union
- import torch
- from torch import sym_float, sym_int
- from torch._subclasses.meta_utils import is_sparse_any
- from torch.overrides import BaseTorchFunctionMode
- from torch.utils._python_dispatch import is_traceable_wrapper_subclass
- from .. import config, graph_break_hints, polyfills, variables
- from ..exc import (
- ObservedAttributeError,
- ObservedUserStopIteration,
- raise_observed_exception,
- unimplemented,
- Unsupported,
- UserError,
- UserErrorType,
- )
- from ..guards import GuardBuilder, install_guard
- from ..replay_record import DummyModule
- from ..source import (
- AttrSource,
- GetItemSource,
- GlobalSource,
- is_constant_source,
- Source,
- TypeSource,
- )
- from ..utils import (
- check_constant_args,
- check_numpy_ndarray_args,
- check_unspec_or_constant_args,
- check_unspec_python_args,
- cmp_name_to_op_mapping,
- dict_methods,
- extract_fake_example_value,
- frozenset_methods,
- get_fake_value,
- guard_if_dyn,
- is_tensor_getset_descriptor,
- is_wrapper_or_member_descriptor,
- istype,
- numpy_operator_wrapper,
- proxy_args_kwargs,
- raise_args_mismatch,
- set_methods,
- str_methods,
- tensortype_to_dtype,
- )
- from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker
- from .constant import CONSTANT_VARIABLE_NONE, ConstantVariable, EnumVariable
- from .dicts import (
- ConstDictVariable,
- DefaultDictVariable,
- DictKeysVariable,
- DictViewVariable,
- FrozensetVariable,
- is_hashable,
- OrderedSetClassVariable,
- SetVariable,
- )
- from .lists import (
- BaseListVariable,
- ListIteratorVariable,
- ListVariable,
- RangeVariable,
- SizeVariable,
- TupleIteratorVariable,
- TupleVariable,
- )
- from .streams import EventVariable, StreamVariable
- from .tensor import (
- FakeItemVariable,
- supported_comparison_ops,
- SymNodeVariable,
- TensorVariable,
- UnspecializedPythonVariable,
- )
- from .user_defined import (
- MutableMappingVariable,
- UserDefinedDictVariable,
- UserDefinedObjectVariable,
- UserDefinedVariable,
- )
- if TYPE_CHECKING:
- # Cyclic dependency...
- from torch._dynamo.codegen import PyCodegen
- from torch._dynamo.symbolic_convert import InstructionTranslator
log = logging.getLogger(__name__)


# Maps every in-place operator function to the out-of-place operator that
# computes the same value (e.g. ``operator.iadd`` -> ``operator.add``).
# Every value must be the *out-of-place* form: an entry that maps an in-place
# operator to itself performs no desugaring at all.
# NOTE(review): the consumers of this table are outside this view; presumably
# it is used to rewrite ``x op= y`` as ``x = x op y`` for immutable operands.
IN_PLACE_DESUGARING_MAP = {
    operator.iadd: operator.add,
    operator.isub: operator.sub,
    operator.imul: operator.mul,
    operator.ifloordiv: operator.floordiv,
    operator.itruediv: operator.truediv,
    operator.imod: operator.mod,
    # Previously mapped to operator.imatmul (itself), which was inconsistent
    # with every other entry in this table.
    operator.imatmul: operator.matmul,
    operator.ilshift: operator.lshift,
    operator.irshift: operator.rshift,
    operator.ipow: operator.pow,
    operator.iand: operator.and_,
    operator.ior: operator.or_,
    operator.ixor: operator.xor,
}
# Signature of the binary-op handlers registered in
# BuiltinVariable._binop_handlers(): each is called as ``handler(tx, a, b)``.
# A handler may return None to signal that it does not apply to the given
# arguments (see list_iadd_handler).
_HandlerCallback = Callable[
    ["InstructionTranslator", typing.Any, typing.Any], VariableTracker | None
]
# Type pattern for one handler argument: either a single VariableTracker
# subclass or a tuple of acceptable subclasses.
_TrackersType = Union[type[VariableTracker], tuple[type[VariableTracker], ...]]
# Comparison operators that have a polyfill implementation.  For these
# operators the comparison handlers fall back to inlining the mapped polyfill
# when no more specific (constant / enum / SymNode) handler matches.
polyfill_fn_mapping = {
    operator.eq: polyfills.cmp_eq,
    operator.ne: polyfills.cmp_ne,
    operator.lt: polyfills.cmp_lt,
    operator.le: polyfills.cmp_le,
    operator.gt: polyfills.cmp_gt,
    operator.ge: polyfills.cmp_ge,
}
# The tuples below group operators by the kind of sample inputs that
# populate_builtin_to_tensor_fn_map() uses to probe them.

# Binary operators (arithmetic, comparison and in-place forms) probed with two
# float tensors.
bin_ops = (
    operator.pow,
    operator.mul,
    operator.matmul,
    operator.floordiv,
    operator.truediv,
    operator.mod,
    operator.add,
    operator.lt,
    operator.gt,
    operator.ge,
    operator.le,
    operator.ne,
    operator.eq,
    operator.sub,
    operator.ipow,
    operator.imul,
    operator.imatmul,
    operator.ifloordiv,
    operator.itruediv,
    operator.imod,
    operator.iadd,
    operator.isub,
)
# Bitwise binary operators, probed with two int32 tensors.
bin_int_ops = (
    operator.and_,
    operator.or_,
    operator.xor,
    operator.iand,
    operator.ixor,
    operator.ior,
)
# Unary operators probed with an int32 tensor.
un_int_ops = (operator.invert,)
# Operators probed with an int32 tensor as the first operand and a Python int
# as the second.
tensor_and_int_ops = (
    operator.lshift,
    operator.rshift,
    operator.ilshift,
    operator.irshift,
    operator.getitem,
)
# Unary operators probed with a float tensor.
un_ops = (
    operator.abs,
    operator.pos,
    operator.neg,
    operator.not_,  # Note: this has a local scalar dense call
    operator.length_hint,
)
# Maps a builtin operator to the torch function that was dispatched when the
# operator was applied to tensor operands.  Filled lazily, once, by
# populate_builtin_to_tensor_fn_map().
BUILTIN_TO_TENSOR_FN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {}

# These functions represent the r* versions of the above ops.
# Basically, if __add__(1, Tensor) is called, it is translated
# to __radd__(Tensor, 1).
# In the builtin var, we check if there is a tensor in the first args position;
# if not, we swap the args and use the r* version of the op.
BUILTIN_TO_TENSOR_RFN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {}
def populate_builtin_to_tensor_fn_map() -> None:
    """Fill BUILTIN_TO_TENSOR_FN_MAP and BUILTIN_TO_TENSOR_RFN_MAP (once).

    Each operator in the module-level operator groups is invoked on small
    sample tensors while a torch-function mode records the function that was
    dispatched.  That function becomes the operator's entry in
    BUILTIN_TO_TENSOR_FN_MAP.

    A second pass invokes the binary operators with a Python int as the first
    operand.  When the dispatched function differs from the forward one, it is
    stored in BUILTIN_TO_TENSOR_RFN_MAP as the reflected variant.

    The function is a no-op once BUILTIN_TO_TENSOR_FN_MAP is non-empty.
    """
    if BUILTIN_TO_TENSOR_FN_MAP:
        # Already populated by an earlier call; nothing to do.
        return

    last_dispatched: Callable[..., Any] | None = None

    class GetMethodMode(BaseTorchFunctionMode):
        """
        Mode to extract the correct methods from torch function invocations
        (Used to get the correct torch.Tensor methods from builtins)
        """

        def __torch_function__(
            self,
            func: Callable[..., Any],
            types: Any,
            args: Sequence[Any] = (),
            kwargs: dict[str, Any] | None = None,
        ) -> Any:
            nonlocal last_dispatched
            # Remember what was dispatched, then run it unchanged.
            last_dispatched = func
            return func(*args, **(kwargs or {}))

    float_lhs = torch.ones(1)
    float_rhs = torch.ones(1)
    int_lhs = torch.ones(1, dtype=torch.int32)
    int_rhs = torch.ones(1, dtype=torch.int32)

    # (operator group, how to invoke one operator of that group on tensors)
    forward_probes: tuple[tuple[Iterable[Any], Callable[..., Any]], ...] = (
        (un_ops, lambda fn: fn(float_lhs)),
        (un_int_ops, lambda fn: fn(int_lhs)),
        (bin_ops, lambda fn: fn(float_lhs, float_rhs)),
        (bin_int_ops, lambda fn: fn(int_lhs, int_rhs)),
        (tensor_and_int_ops, lambda fn: fn(int_lhs, 0)),
    )
    # Same idea with a plain int on the left, which yields the r* ops
    # (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int)).
    reverse_probes: tuple[tuple[Iterable[Any], Callable[..., Any]], ...] = (
        (bin_ops, lambda fn: fn(1, float_rhs)),
        (bin_int_ops, lambda fn: fn(1, int_rhs)),
        (tensor_and_int_ops, lambda fn: fn(0, int_lhs)),
    )
    # Operators that are never probed in reflected form.
    no_reverse_probe = {operator.matmul, operator.imatmul, operator.getitem}

    with GetMethodMode():
        for group, invoke in forward_probes:
            for builtin_op in group:
                invoke(builtin_op)
                assert last_dispatched is not None
                BUILTIN_TO_TENSOR_FN_MAP[builtin_op] = last_dispatched

        # gather the reverse functions
        for group, invoke in reverse_probes:
            for builtin_op in group:
                if builtin_op in no_reverse_probe:
                    continue
                invoke(builtin_op)
                assert last_dispatched is not None
                # Only record a reflected entry when it is a distinct function.
                if last_dispatched != BUILTIN_TO_TENSOR_FN_MAP[builtin_op]:
                    BUILTIN_TO_TENSOR_RFN_MAP[builtin_op] = last_dispatched
- class BuiltinVariable(VariableTracker):
- """
- A VariableTracker that represents a built-in value (functions and operators).
- A lot of the code here assumes it will be a function object.
- The BuiltinVariable class wraps Python built-in functions (like len, isinstance, etc.)
- and operators (like +, -, *, etc.) to enable symbolic execution during tracing. This allows
- Dynamo to properly handle these operations when converting Python code to FX graphs while
- maintaining correct semantics and enabling optimizations.
- """
- _SENTINEL = object()
- _nonvar_fields = {
- "fn",
- *VariableTracker._nonvar_fields,
- }
@classmethod
def create_with_source(cls, value: Any, source: Source) -> "BuiltinVariable":
    """Build a variable for ``value`` that is tied to ``source``.

    A BUILTIN_MATCH guard created from ``source`` is installed before the
    new instance is constructed.
    """
    builtin_guard = source.make_guard(GuardBuilder.BUILTIN_MATCH)
    install_guard(builtin_guard)
    return cls(value, source=source)
@staticmethod
@functools.cache
def _constant_fold_functions() -> set[Callable[..., Any]]:
    """Return the cached set of callables checked by ``can_constant_fold_through``.

    The set is made up of a fixed list of Python builtins and ``operator``
    functions, the values of ``supported_comparison_ops``, and every member
    of the ``math`` module that has the same type as ``math.sqrt``.
    """
    from .tensor import supported_comparison_ops

    foldable: set[Callable[..., Any]] = {
        # Python builtins
        abs,
        all,
        any,
        bool,
        callable,
        chr,
        complex,
        divmod,
        float,
        getattr,
        int,
        len,
        max,
        min,
        ord,
        pow,
        repr,
        round,
        str,
        str.format,
        sum,
        type,
        # Unary operators
        operator.abs,
        operator.pos,
        operator.neg,
        operator.not_,
        operator.truth,
        operator.invert,
        operator.index,
        operator.length_hint,
        # Binary operators
        operator.pow,
        operator.mul,
        operator.matmul,
        operator.floordiv,
        operator.truediv,
        operator.mod,
        operator.add,
        operator.sub,
        operator.getitem,
        operator.lshift,
        operator.rshift,
        operator.and_,
        operator.or_,
        operator.xor,
        # In-place operators
        operator.ipow,
        operator.imul,
        operator.imatmul,
        operator.ifloordiv,
        operator.itruediv,
        operator.imod,
        operator.iadd,
        operator.isub,
        operator.ilshift,
        operator.irshift,
        operator.iand,
        operator.ixor,
        operator.ior,
    }
    foldable.update(supported_comparison_ops.values())

    # Pick up every math-module member whose type matches math.sqrt's.
    math_function_type = type(math.sqrt)
    foldable.update(
        member
        for member in vars(math).values()
        if isinstance(member, math_function_type)
    )
    return foldable
def can_constant_fold_through(self) -> bool:
    """Return True when ``self.fn`` is in the constant-fold function set."""
    foldable = self._constant_fold_functions()
    return self.fn in foldable
- @staticmethod
- @functools.cache
- def _fx_graph_functions() -> set[Callable[..., Any]]:
- fns = {
- operator.abs,
- operator.pos,
- operator.neg,
- operator.not_,
- operator.invert,
- operator.pow,
- operator.mul,
- operator.matmul,
- operator.floordiv,
- operator.truediv,
- operator.mod,
- operator.add,
- operator.lt,
- operator.gt,
- operator.ge,
- operator.le,
- operator.ne,
- operator.eq,
- operator.sub,
- operator.length_hint,
- operator.lshift,
- operator.rshift,
- operator.and_,
- operator.or_,
- operator.xor,
- operator.ipow,
- operator.imul,
- operator.imatmul,
- operator.ifloordiv,
- operator.itruediv,
- operator.getitem,
- operator.imod,
- operator.iadd,
- operator.isub,
- operator.ilshift,
- operator.irshift,
- operator.iand,
- operator.ixor,
- operator.ior,
- }
- return fns # type: ignore[return-value]
- @staticmethod
- @functools.cache
- def _binops() -> dict[
- Callable[..., object], tuple[list[str], Callable[..., object]]
- ]:
- # function -> ([forward name, reverse name, in-place name], in-place op)
- fns: dict[Callable[..., object], tuple[list[str], Callable[..., object]]] = {
- operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd),
- operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub),
- operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul),
- operator.truediv: (
- ["__truediv__", "__rtruediv__", "__itruediv__"],
- operator.itruediv,
- ),
- operator.floordiv: (
- ["__floordiv__", "__rfloordiv__", "__ifloordiv__"],
- operator.ifloordiv,
- ),
- operator.mod: (["__mod__", "__rmod__", "__imod__"], operator.imod),
- pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow),
- operator.pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow),
- operator.lshift: (
- ["__lshift__", "__rlshift__", "__ilshift__"],
- operator.ilshift,
- ),
- operator.rshift: (
- ["__rshift__", "__rrshift__", "__irshift__"],
- operator.irshift,
- ),
- operator.xor: (["__xor__", "__rxor__", "__ixor__"], operator.xor),
- # NB: The follow binary operators are not supported for now, since the
- # corresponding magic methods aren't defined on SymInt / SymFloat:
- # operator.matmul
- # divmod
- # operator.and_
- # operator.or_
- }
- return fns
- @staticmethod
- @functools.cache
- def _binop_handlers() -> dict[
- Callable[..., object],
- list[
- tuple[
- tuple[
- type[VariableTracker],
- _TrackersType,
- ],
- _HandlerCallback,
- ]
- ],
- ]:
- # Multiple dispatch mechanism defining custom binop behavior for certain type
- # combinations. Handlers are attempted in order, and will be used if the type checks
- # match. They are expected to have the signature:
- # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker
- from .functions import BaseUserFunctionVariable, UserFunctionVariable
- from .nn_module import NNModuleVariable
- from .tensor import supported_const_comparison_ops
- from .torch import BaseTorchVariable
- from .user_defined import (
- UserDefinedClassVariable,
- UserDefinedObjectVariable,
- UserDefinedVariable,
- )
- # Override table contains: op_fn -> [list of handlers]
- op_handlers: dict[Any, list[Any]] = {}
- for (
- op,
- (magic_method_names, in_place_op),
- ) in BuiltinVariable._binops().items():
- op_handlers[op] = []
- op_handlers[in_place_op] = []
- forward_name, reverse_name, inplace_name = magic_method_names
- # User-defined args (highest precedence)
- def user_defined_handler(
- tx: "InstructionTranslator",
- a: VariableTracker,
- b: VariableTracker,
- *,
- forward_name: str = forward_name,
- reverse_name: str = reverse_name,
- ) -> VariableTracker:
- # Manually handle reversing logic if needed (e.g. call __radd__)
- # TODO: If we expand this to handle tensor args, we need to manually
- # handle cases like this:
- #
- # class A(int):
- # def __radd__(self, other):
- # print("woof")
- # torch.randn(3) + A(3)
- #
- # In this example, A.__radd__() is not called -> nothing is printed, because
- # Tensor.__add__ only does a subtype test against int, ignoring the subclass.
- # To be fully correct, we should not call A.__radd__() here, and there may be
- # other cases to reason about and add exceptions for.
- if isinstance(a, UserDefinedVariable):
- return a.call_method(tx, forward_name, [b], {})
- else:
- return b.call_method(tx, reverse_name, [a], {})
- op_handlers[op].append(
- ((UserDefinedVariable, VariableTracker), user_defined_handler)
- )
- op_handlers[op].append(
- ((VariableTracker, UserDefinedVariable), user_defined_handler)
- )
- def user_defined_inplace_handler(
- tx: "InstructionTranslator",
- a: VariableTracker,
- b: VariableTracker,
- *,
- forward_name: str = inplace_name,
- ) -> VariableTracker:
- return a.call_method(tx, forward_name, [b], {})
- op_handlers[in_place_op].append(
- ((UserDefinedVariable, VariableTracker), user_defined_inplace_handler)
- )
- op_handlers[in_place_op].append(
- ((VariableTracker, UserDefinedVariable), user_defined_inplace_handler)
- )
- # Dynamic shape args
- def dynamic_handler(
- tx: "InstructionTranslator",
- a: VariableTracker,
- b: VariableTracker,
- *,
- fn: Callable[..., Any] = op,
- ) -> VariableTracker:
- from .builder import wrap_fx_proxy
- return wrap_fx_proxy(
- tx,
- tx.output.create_proxy(
- "call_function", fn, *proxy_args_kwargs([a, b], {})
- ),
- )
- op_handlers[op].append(
- ((SymNodeVariable, VariableTracker), dynamic_handler)
- )
- op_handlers[op].append(
- ((VariableTracker, SymNodeVariable), dynamic_handler)
- )
- # NB: Prefer out-of-place op when calling in-place op to generate valid graph
- op_handlers[in_place_op].append(
- ((SymNodeVariable, VariableTracker), dynamic_handler)
- )
- op_handlers[in_place_op].append(
- ((VariableTracker, SymNodeVariable), dynamic_handler)
- )
- # Special cases - lower precedence but still prefer these over constant folding
- # List-like addition (e.g. [1, 2] + [3, 4])
- def tuple_add_handler(
- tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker
- ) -> VariableTracker:
- return TupleVariable([*a.items, *b.unpack_var_sequence(tx)])
- def size_add_handler(
- tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker
- ) -> VariableTracker:
- return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])
- list_like_addition_handlers: list[
- tuple[
- tuple[
- type[VariableTracker],
- _TrackersType,
- ],
- _HandlerCallback,
- ]
- ] = [
- # NB: Prefer the tuple-specific logic over base logic because of
- # some SizeVariable weirdness. Specifically, the tuple-specific logic
- # drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
- (
- (SizeVariable, SizeVariable),
- size_add_handler,
- ),
- (
- (SizeVariable, TupleVariable),
- size_add_handler,
- ),
- (
- (TupleVariable, SizeVariable),
- size_add_handler,
- ),
- (
- (TupleVariable, TupleVariable),
- tuple_add_handler,
- ),
- (
- (TupleVariable, ConstantVariable),
- tuple_add_handler,
- ),
- (
- (ConstantVariable, TupleVariable),
- lambda tx, a, b: TupleVariable(
- [
- *a.unpack_var_sequence(tx),
- *b.items,
- ],
- ),
- ),
- (
- (
- ListVariable,
- (BaseListVariable, ConstantVariable, ListIteratorVariable),
- ),
- lambda tx, a, b: ListVariable(
- [*a.items, *b.unpack_var_sequence(tx)],
- mutation_type=ValueMutationNew(),
- ),
- ),
- (
- (BaseListVariable, BaseListVariable),
- lambda tx, a, b: type(a)(
- [
- *a.items,
- *b.items,
- ]
- ),
- ),
- ]
- op_handlers[operator.add].extend(list_like_addition_handlers)
- def list_iadd_handler(
- tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker
- ) -> Any:
- if a.is_immutable() or not b.has_unpack_var_sequence(tx):
- # Handler doesn't apply
- return None
- seq = b.unpack_var_sequence(tx)
- tx.output.side_effects.mutation(a)
- a.items.extend(seq)
- return a
- list_like_iadd_handlers: list[Any] = [
- (
- (ListVariable, VariableTracker),
- list_iadd_handler,
- ),
- (
- (TupleVariable, TupleVariable),
- tuple_add_handler,
- ),
- (
- (TupleVariable, ConstantVariable),
- tuple_add_handler,
- ),
- ]
- op_handlers[operator.iadd].extend(list_like_iadd_handlers)
- # List-like expansion (e.g. [1, 2, 3] * 3)
- def expand_list_like(
- tx: "InstructionTranslator", lst: VariableTracker, const: VariableTracker
- ) -> VariableTracker:
- if not isinstance(lst, BaseListVariable) and lst.is_python_constant():
- lst, const = const, lst
- try:
- assert isinstance(lst, BaseListVariable)
- return lst.__class__(
- items=lst.items * const.as_python_constant(),
- mutation_type=ValueMutationNew(),
- )
- except MemoryError as exc:
- raise_observed_exception(
- type(exc),
- tx,
- args=list(map(ConstantVariable.create, exc.args)),
- )
- list_like_expansion_handlers: list[
- tuple[
- tuple[type[VariableTracker], type[VariableTracker]],
- _HandlerCallback,
- ]
- ] = [
- ((ListVariable, ConstantVariable), expand_list_like),
- ((TupleVariable, ConstantVariable), expand_list_like),
- ((ConstantVariable, ListVariable), expand_list_like),
- ((ConstantVariable, TupleVariable), expand_list_like),
- ]
- op_handlers[operator.mul].extend(list_like_expansion_handlers)
- def create_cmp_op_handlers(
- op: Callable[..., Any],
- ) -> list[tuple[tuple[_TrackersType, _TrackersType], _HandlerCallback]]:
- def compare_by_value(
- tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
- ) -> VariableTracker:
- try:
- return ConstantVariable(op(a.value, b.value)) # type: ignore[attr-defined]
- except TypeError as exc:
- raise_observed_exception(
- type(exc),
- tx,
- args=list(map(ConstantVariable.create, exc.args)),
- )
- result: list[
- tuple[
- tuple[
- _TrackersType,
- _TrackersType,
- ],
- _HandlerCallback,
- ]
- ] = [((ConstantVariable, ConstantVariable), compare_by_value)]
- if op in polyfill_fn_mapping:
- # For constants, speedup the comparison instead of using
- # polyfill. Removing this line causes major regression for pr
- # time benchmark - add_loop_eager.
- result = [
- ((ConstantVariable, ConstantVariable), compare_by_value),
- ((EnumVariable, EnumVariable), compare_by_value),
- ]
- op_var = BuiltinVariable(op)
- # Special handling of SymNode variable
- result.extend(
- [
- (
- (SymNodeVariable, VariableTracker),
- op_var._comparison_with_symnode,
- ),
- (
- (VariableTracker, SymNodeVariable),
- op_var._comparison_with_symnode,
- ),
- ]
- )
- def handler(
- tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
- ) -> VariableTracker:
- return tx.inline_user_function_return(
- VariableTracker.build(tx, polyfill_fn_mapping[op]), [a, b], {}
- )
- result.append(((VariableTracker, VariableTracker), handler))
- return result
- result = [((ConstantVariable, ConstantVariable), compare_by_value)]
- if op in supported_const_comparison_ops.values() and op.__name__.startswith(
- "is_"
- ):
- # Tensor is None, List is not None, etc
- none_result = op(object(), None)
- def never(
- tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
- ) -> VariableTracker:
- return ConstantVariable(none_result)
- obj_op_none = never
- none_op_obj = never
- types_that_are_never_none = (
- TensorVariable,
- SymNodeVariable,
- NNModuleVariable,
- BaseListVariable,
- UserDefinedVariable,
- BaseUserFunctionVariable,
- ConstDictVariable,
- BaseTorchVariable,
- )
- result.extend(
- [
- (
- (types_that_are_never_none, ConstantVariable),
- obj_op_none,
- ),
- (
- (ConstantVariable, types_that_are_never_none),
- none_op_obj,
- ),
- ]
- )
- op_var = BuiltinVariable(op)
- result.extend(
- [
- (
- (
- (UserFunctionVariable, BuiltinVariable),
- (UserFunctionVariable, BuiltinVariable),
- ),
- lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)),
- ),
- (
- (
- NNModuleVariable,
- NNModuleVariable,
- ),
- lambda tx, a, b: ConstantVariable(
- op(
- tx.output.get_submodule(a.module_key),
- tx.output.get_submodule(b.module_key),
- )
- ),
- ),
- (
- (UserDefinedObjectVariable, UserDefinedObjectVariable),
- compare_by_value,
- ),
- (
- (UserDefinedClassVariable, UserDefinedClassVariable),
- compare_by_value,
- ),
- (
- (
- (StreamVariable, EventVariable, ConstantVariable),
- (StreamVariable, EventVariable, ConstantVariable),
- ),
- compare_by_value,
- ),
- (
- (TensorVariable, VariableTracker),
- op_var._comparison_with_tensor,
- ),
- (
- (VariableTracker, TensorVariable),
- op_var._comparison_with_tensor,
- ),
- (
- (SymNodeVariable, VariableTracker),
- op_var._comparison_with_symnode,
- ),
- (
- (VariableTracker, SymNodeVariable),
- op_var._comparison_with_symnode,
- ),
- ]
- )
- def handle_is(
- tx: "InstructionTranslator",
- left: VariableTracker,
- right: VariableTracker,
- ) -> VariableTracker | None:
- # If the two objects are of different type, we can safely return False
- # and True for `is` and `is not`, respectively
- if type(left) is not type(right):
- return ConstantVariable.create(op.__name__ != "is_")
- if left is right:
- return ConstantVariable.create(op(left, right))
- if istype(left, variables.ObjectVariable) and istype(
- right, variables.ObjectVariable
- ):
- return ConstantVariable.create(op(left.value, right.value))
- if (
- istype(left, variables.ExceptionVariable)
- and istype(right, variables.ExceptionVariable)
- and left.exc_type is not right.exc_type
- ):
- return ConstantVariable.create(op(left, right))
- result.append(((VariableTracker, VariableTracker), handle_is)) # type: ignore[arg-type]
- return result
- for op in supported_comparison_ops.values():
- assert callable(op)
- assert op not in op_handlers
- op_handlers[op] = create_cmp_op_handlers(op)
- return op_handlers
@staticmethod
def _find_binop_handler(
    op: Callable[..., Any], a_type: type[VariableTracker], b_type: type
) -> list[_HandlerCallback] | None:
    """Collect the binary-op handlers registered for ``op`` that accept the operand types.

    Returns ``None`` when ``op`` has no entry in ``_binop_handlers()``;
    otherwise returns the (possibly empty) list of handlers, in registration
    order, whose declared type pair is matched by ``(a_type, b_type)`` via
    ``issubclass``.
    """
    registered = BuiltinVariable._binop_handlers().get(op)
    if registered is None:
        return None

    return [
        callback
        for (lhs_type, rhs_type), callback in registered
        if issubclass(a_type, lhs_type) and issubclass(b_type, rhs_type)
    ]
def can_insert_in_graph(self) -> bool:
    """Return whether the wrapped builtin is a member of ``_fx_graph_functions()``."""
    graph_functions = self._fx_graph_functions()
    return self.fn in graph_functions
def __init__(self, fn: Any, **kwargs: Any) -> None:
    """Store the wrapped builtin ``fn``; remaining keyword arguments go to the base class."""
    # The base initializer runs first so that ``fn`` is assigned on a fully
    # initialized tracker.
    super().__init__(**kwargs)
    self.fn = fn
def __repr__(self) -> str:
    """Render as ``ClassName(<builtin name>)``, using ``None`` when no builtin is set."""
    label = "None" if self.fn is None else self.fn.__name__
    return f"{self.__class__.__name__}({label})"
def as_python_constant(self) -> Any:
    """Return the wrapped builtin object itself."""
    wrapped = self.fn
    return wrapped
def as_proxy(self) -> Any:
    """Return the proxy value for this builtin.

    ``bool``, ``int`` and ``float`` are translated to ``torch.bool``,
    ``torch.int64`` and ``torch.float64``; every other builtin is delegated
    to the base class implementation.
    """
    builtin_to_dtype = {
        bool: torch.bool,
        int: torch.int64,
        float: torch.float64,
    }
    dtype = builtin_to_dtype.get(self.fn)
    if dtype is not None:
        return dtype
    return super().as_proxy()
def reconstruct(self, codegen: "PyCodegen") -> None:
    """Emit a global load of the wrapped builtin's name through ``codegen``.

    Asserts that the wrapped object lives in the ``builtins`` module and that
    its name is not present in the translated frame's globals.
    """
    builtin_name = self.fn.__name__
    assert self.fn.__module__ == "builtins"
    assert builtin_name not in codegen.tx.f_globals, "shadowed global"
    load_instruction = codegen.create_load_global(builtin_name, add=True)
    codegen.append_output(load_instruction)
def constant_args(self, *args: VariableTracker, **kwargs: VariableTracker) -> bool:
    """Return the result of ``check_constant_args`` for the given positional and keyword arguments."""
    all_constant = check_constant_args(args, kwargs)
    return all_constant
def tensor_args(self, *args: VariableTracker) -> bool:
    """Report whether any argument is a tensor.

    A ``GetAttrVariable`` argument makes the answer ``False`` immediately,
    even if a tensor was seen earlier. With no arguments the answer is
    ``False``.
    """
    found_tensor = False
    for candidate in args:
        if isinstance(candidate, variables.GetAttrVariable):
            return False
        # Once a tensor has been found there is no need to query again.
        if not found_tensor:
            found_tensor = candidate.is_tensor()
    return found_tensor
def tensor_args_type(self, arg_types: list[type]) -> bool:
    """Type-level counterpart of ``tensor_args``.

    Returns ``False`` as soon as a ``GetAttrVariable`` subclass is seen;
    otherwise returns whether any entry is a ``TensorVariable`` subclass.
    """
    found_tensor = False
    for candidate_type in arg_types:
        if issubclass(candidate_type, variables.GetAttrVariable):
            return False
        if not found_tensor:
            found_tensor = issubclass(candidate_type, variables.TensorVariable)
    return found_tensor
def python_and_tensor_constant_only(
    self, *args: VariableTracker, **kwargs: VariableTracker
) -> bool:
    """Check that every argument is either a constant-sourced tensor or a constant.

    Positional and keyword values are split into tensors and non-tensors.
    Every tensor must have a non-``None`` source accepted by
    ``is_constant_source``; the non-tensors are then checked with
    ``constant_args`` (only when the tensor check passed).
    """
    tensors: list[VariableTracker] = []
    others: list[VariableTracker] = []
    for value in itertools.chain(args, kwargs.values()):
        bucket = tensors if value.is_tensor() else others
        bucket.append(value)

    tensors_are_constant = all(
        t.source is not None and is_constant_source(t.source) for t in tensors
    )
    return tensors_are_constant and self.constant_args(*others)
- @staticmethod
- def unwrap_unspec_args_kwargs(
- args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker]
- ) -> tuple[list[Any], dict[str, Any]]:
- return [x.as_python_constant() for x in args], {
- k: v.as_python_constant() for k, v in kwargs.items()
- }
def has_constant_handler(
    self, args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker]
) -> bool:
    """Return whether this call can be constant folded.

    Requires ``can_constant_fold_through()`` to be truthy and the arguments to
    pass ``check_unspec_or_constant_args``; the argument check is skipped when
    folding is not possible.
    """
    foldable = self.can_constant_fold_through()
    return foldable and check_unspec_or_constant_args(args, kwargs)
@staticmethod
def _make_handler(
    fn: Callable[..., Any], arg_types: list[type], has_kwargs: bool
) -> Callable[
    [
        "InstructionTranslator",
        tuple[VariableTracker, ...],
        dict[str, VariableTracker],
    ],
    VariableTracker | None,
]:
    """Build the dispatch function for calling builtin ``fn`` with ``arg_types``.

    The returned callable takes ``(tx, args, kwargs)``. Depending on ``fn``
    and the argument types it is one of:

    * a wrapper that realizes lazy arguments and re-enters ``call_function``;
    * a constructor of ``ExceptionVariable`` when ``fn`` is an exception class;
    * ``_handle_insert_op_in_graph`` when tensor arguments are present;
    * otherwise a dispatcher that tries, in order, the graph-insert handler
      (kwargs case), matching binary-op handlers, the ``call_<name>`` method
      of ``BuiltinVariable``, and constant folding, and reports a graph break
      through ``unimplemented`` when none of them produces a truthy result.

    Args:
        fn: the builtin being called.
        arg_types: the tracker types of the positional arguments.
        has_kwargs: whether the call carries keyword arguments.
    """
    from .lazy import LazyVariableTracker

    obj = BuiltinVariable(fn)
    handlers: list[_HandlerCallback] = []

    if any(issubclass(t, LazyVariableTracker) for t in arg_types):
        # Realize the lazy arguments and dispatch again on their real types.
        return lambda tx, args, kwargs: obj.call_function(
            tx, [v.realize() for v in args], kwargs
        )

    if inspect.isclass(fn) and (
        issubclass(fn, BaseException)
        # GeneratorExit doesn't inherit from Exception
        # >>> issubclass(GeneratorExit, Exception)
        # False
        or fn is GeneratorExit
    ):

        def create_exception_class_object(
            tx: "InstructionTranslator",
            args: tuple[VariableTracker, ...],
            kwargs: dict[str, VariableTracker],
        ) -> VariableTracker:
            if fn is AssertionError and not all(
                x.is_python_constant() and isinstance(x.as_python_constant(), str)
                for x in args
            ):
                unimplemented(
                    gb_type="assert with non-string message",
                    context=str(args),
                    explanation="Dynamo only supports asserts with string messages",
                    hints=[*graph_break_hints.SUPPORTABLE],
                )

            return variables.ExceptionVariable(fn, args, kwargs)

        return create_exception_class_object

    if obj.can_insert_in_graph() and not (
        fn is operator.getitem
        and not issubclass(arg_types[0], variables.TensorVariable)
    ):
        if obj.tensor_args_type(arg_types):
            return obj._handle_insert_op_in_graph
        elif has_kwargs:
            # need runtime check for kwargs
            handlers.append(obj._handle_insert_op_in_graph)

    # Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.)
    # NB: Tensor args are handled above and not here
    if len(arg_types) == 2 and not has_kwargs:
        # Try to find a handler for the arg types; otherwise, fall through to constant handler
        binop_handlers = BuiltinVariable._find_binop_handler(fn, *arg_types)
        if not binop_handlers:
            pass
        elif len(binop_handlers) == 1:
            (binop_handler,) = binop_handlers
            handlers.append(lambda tx, args, _: binop_handler(tx, *args))
        else:

            def call_binop_handlers(
                tx: "InstructionTranslator", args: Any, _: Any
            ) -> Any:
                # First handler producing a truthy result wins.
                # pyrefly: ignore [not-iterable]
                for candidate in binop_handlers:
                    rv = candidate(tx, *args)
                    if rv:
                        return rv
                return None

            handlers.append(call_binop_handlers)

    self_handler = getattr(obj, f"call_{fn.__name__}", None)
    if self_handler:

        def call_self_handler(
            tx: "InstructionTranslator",
            args: Sequence[VariableTracker],
            kwargs: dict[str, VariableTracker],
        ) -> VariableTracker | None:
            try:
                # pyrefly: ignore [not-callable]
                return self_handler(tx, *args, **kwargs)
            except TypeError:
                # Check if binding is bad. inspect signature bind is expensive.
                # So check only when handler call fails.
                try:
                    # pyrefly: ignore [bad-argument-type]
                    inspect.signature(self_handler).bind(tx, *args, **kwargs)
                except TypeError as e:
                    has_constant_handler = obj.has_constant_handler(args, kwargs)
                    if not has_constant_handler:
                        log.warning(  # noqa: G200
                            "incorrect arg count %s %s and no constant handler",
                            self_handler,
                            e,
                        )
                        unimplemented(
                            gb_type="invalid call to builtin op handler",
                            context=f"invalid args to {self_handler}: {args} {kwargs}",
                            explanation=f"Encountered TypeError when trying to handle op {fn.__name__}",
                            hints=[*graph_break_hints.DIFFICULT],
                        )
                else:
                    # The arguments bind fine, so the TypeError came from
                    # inside the handler: propagate it.
                    raise
            except Unsupported as exc:
                has_constant_handler = obj.has_constant_handler(args, kwargs)
                if not has_constant_handler:
                    raise
                # Actually, we will handle this just fine
                exc.remove_from_stats()
            return None

        handlers.append(call_self_handler)

    if obj.can_constant_fold_through():
        if (
            all(issubclass(x, ConstantVariable) for x in arg_types)
            and not has_kwargs
        ):

            def constant_fold_handler(
                tx: "InstructionTranslator",
                args: Sequence[VariableTracker],
                kwargs: dict[str, VariableTracker],
            ) -> VariableTracker | None:
                # fast path
                try:
                    res = fn(
                        *[x.as_python_constant() for x in args],
                    )
                # The specific clause must come before the generic
                # ``Exception`` clause, otherwise the broad clause catches it
                # first and this branch never runs. This matches the order
                # used by the runtime-check path below.
                except AsPythonConstantNotImplementedError as exc:
                    unimplemented(
                        gb_type="constant fold exception",
                        context=f"attempted to run function {fn} with arguments {args}",
                        explanation="Encountered exception when attempting to constant fold.",
                        hints=[*graph_break_hints.DYNAMO_BUG],
                        from_exc=exc,
                    )
                except Exception as exc:
                    raise_observed_exception(
                        type(exc),
                        tx,
                        args=list(map(ConstantVariable.create, exc.args)),
                    )
                # pyrefly: ignore [unbound-name]
                return VariableTracker.build(tx, res)

        else:

            def constant_fold_handler(
                tx: "InstructionTranslator",
                args: Sequence[VariableTracker],
                kwargs: dict[str, VariableTracker],
            ) -> VariableTracker | None:
                # path with a runtime check
                if check_unspec_or_constant_args(args, kwargs):
                    try:
                        res = fn(
                            *[x.as_python_constant() for x in args],
                            **{
                                k: v.as_python_constant() for k, v in kwargs.items()
                            },
                        )
                    except AsPythonConstantNotImplementedError as exc:
                        unimplemented(
                            gb_type="constant fold exception",
                            context=f"attempted to run function {fn} with arguments {args}",
                            explanation="Encountered exception when attempting to constant fold.",
                            hints=[*graph_break_hints.DYNAMO_BUG],
                            from_exc=exc,
                        )
                    except Exception as exc:
                        raise_observed_exception(
                            type(exc),
                            tx,
                            args=list(map(ConstantVariable.create, exc.args)),
                        )
                    # pyrefly: ignore [unbound-name]
                    return VariableTracker.build(tx, res)
                return None

        handlers.append(constant_fold_handler)

    def call_unimplemented(args: Sequence[VariableTracker]) -> None:
        real_arg_types = [arg.python_type_name() for arg in args]
        unimplemented(
            gb_type="Failed to trace builtin operator",
            context=f"builtin {fn.__name__} {arg_types} {has_kwargs}",
            explanation=f"Dynamo does not know how to trace builtin operator `{fn.__name__}` "
            f"with argument types {real_arg_types} (has_kwargs {has_kwargs})",
            hints=[
                f"Avoid calling builtin `{fn.__name__}` with argument types {real_arg_types}. "
                f"Consider using an equivalent alternative function/method to `{fn.__name__}`.",
                "If you are attempting to call a logging function (e.g. `print`), "
                "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.",
                "Please report an issue to PyTorch.",
            ],
        )

    if len(handlers) == 0:
        return lambda tx, args, kwargs: call_unimplemented(args)
    elif len(handlers) == 1:
        (handler,) = handlers

        def builtin_dispatch(
            tx: "InstructionTranslator",
            args: Sequence[VariableTracker],
            kwargs: dict[str, VariableTracker],
        ) -> VariableTracker | None:
            rv = handler(tx, args, kwargs)
            if rv:
                return rv
            call_unimplemented(args)
            return rv

    else:

        def builtin_dispatch(
            tx: "InstructionTranslator",
            args: Sequence[VariableTracker],
            kwargs: dict[str, VariableTracker],
        ) -> VariableTracker | None:
            rv = None
            # Try each handler in registration order; first truthy result wins.
            for candidate in handlers:
                rv = candidate(tx, args, kwargs)
                if rv:
                    return rv
            call_unimplemented(args)
            return rv

    return builtin_dispatch
def call_vars(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker:
    """Trace ``vars(obj)`` as a lookup of ``obj.__dict__``.

    The zero-argument form is reported as a graph break. If the attribute
    lookup raises ``ObservedAttributeError``, an observed ``TypeError`` is
    raised instead.
    """
    if not args:
        unimplemented(
            gb_type="unimplemented builtin op vars() with no arguments",
            context=f"vars: {self} {args}",
            explanation=f"Dynamo does not know how to trace builtin operator {self.fn} with no arguments",
            hints=[*graph_break_hints.SUPPORTABLE],
        )
    assert len(args) == 1
    target = args[0]
    # vars(obj) is obj.__dict__ if __dict__ is present else TypeError
    try:
        return target.var_getattr(tx, "__dict__")
    except ObservedAttributeError:
        raise_observed_exception(TypeError, tx)
def _handle_insert_op_in_graph(
    self,
    tx: "InstructionTranslator",
    args: Sequence[VariableTracker],
    kwargs: dict[str, VariableTracker],
) -> VariableTracker | None:
    """Record a call of the wrapped builtin as an FX graph node.

    Returns ``None`` when keyword arguments are present but no argument is a
    tensor. Calls eligible for ``__torch_function__`` dispatch are routed
    through ``dispatch_torch_function``. Otherwise a ``call_function`` proxy
    is created and wrapped according to the argument kinds. A
    ``NotImplementedError`` raised while doing so is reported as a graph
    break.
    """
    from .builder import wrap_fx_proxy, wrap_fx_proxy_cls

    if kwargs and not self.tensor_args(*args, *kwargs.values()):
        return None

    # insert handling for torch function here
    from .builder import SourcelessBuilder
    from .torch_function import can_dispatch_torch_function, dispatch_torch_function

    global BUILTIN_TO_TENSOR_RFN_MAP, BUILTIN_TO_TENSOR_FN_MAP
    if can_dispatch_torch_function(tx, args, kwargs):
        if tx.export:
            # Export serde does not handle method descriptors today, so the
            # builtin is only remapped to tensor methods when not exporting.
            dispatch_target = self.fn
        else:
            # Ensure the builtin maps are populated before accessing them
            populate_builtin_to_tensor_fn_map()
            # Use sourceless builder, we built the map ourselves
            if args[0].is_tensor():
                dispatch_target = BUILTIN_TO_TENSOR_FN_MAP[self.fn]
            else:
                if self.fn in BUILTIN_TO_TENSOR_RFN_MAP:
                    dispatch_target = BUILTIN_TO_TENSOR_RFN_MAP[self.fn]
                else:
                    dispatch_target = BUILTIN_TO_TENSOR_FN_MAP[self.fn]
                # swap args and call reverse version of func
                args[0], args[1] = args[1], args[0]  # type: ignore[index]

        fn_var = SourcelessBuilder.create(tx, dispatch_target)
        return dispatch_torch_function(tx, fn_var, args, kwargs)

    fn = self.fn
    try:
        # Constant fold for constant tensor and python constants
        if self.python_and_tensor_constant_only(*args, **kwargs):
            from ..bytecode_transformation import unique_id
            from .functions import invoke_and_store_as_constant

            return invoke_and_store_as_constant(
                tx, fn, unique_id(fn.__name__), args, kwargs
            )

        if fn in IN_PLACE_DESUGARING_MAP and isinstance(
            args[0], variables.ConstantVariable
        ):
            # In-place operators like += usually mutate tensor
            # values, but in the edge case of immutable values they
            # re-bind the variable.
            #
            # The easiest way to keep the graph consistent in this
            # scenario is to de-sugar eagerly.
            fn = IN_PLACE_DESUGARING_MAP[fn]
            args = [args[0], args[1]]  # type: ignore[assignment]

        if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
            # Standard indexing will force specialization due to
            # __index__. Rewrite as a regular torch op which will
            # trace fine
            fn = torch.select
            args = [
                args[0],
                variables.ConstantVariable.create(0),
                args[1],
            ]  # type: ignore[assignment]

        # Interaction between ndarray and tensors:
        # We prefer the tensor op whenever there are tensors involved
        # NB: Use exact type check here - NumpyNdarrayVariable is a TensorVariable
        # subclass but should NOT trigger the tensor path
        if check_numpy_ndarray_args(args, kwargs) and not any(
            type(arg) is TensorVariable for arg in args
        ):
            numpy_proxy = tx.output.create_proxy(
                "call_function",
                numpy_operator_wrapper(fn),
                *proxy_args_kwargs(args, kwargs),
            )
            return wrap_fx_proxy_cls(variables.NumpyNdarrayVariable, tx, numpy_proxy)

        if fn is operator.eq and len(args) == 2 and args[0].is_tensor():
            # Dynamo expects `__eq__` str while operator.eq gives just `eq`
            # TODO - supporting all comparison operators could also work but
            # it fails lots of tests because graph str changes.
            return args[0].call_method(tx, "__eq__", list(args[1:]), kwargs)

        proxy = tx.output.create_proxy(
            "call_function",
            fn,
            *proxy_args_kwargs(args, kwargs),
        )

        if any(isinstance(arg, FakeItemVariable) for arg in args):
            return wrap_fx_proxy_cls(
                FakeItemVariable,
                tx,
                proxy,
            )

        if check_unspec_python_args(args, kwargs):
            raw_args, raw_kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
            raw_value = fn(*raw_args, **raw_kwargs)
            need_unwrap = any(
                x.need_unwrap
                for x in itertools.chain(args, kwargs.values())
                if isinstance(x, variables.UnspecializedPythonVariable)
            )
            return wrap_fx_proxy_cls(
                UnspecializedPythonVariable,
                tx,
                proxy,
                raw_value=raw_value,
                need_unwrap=need_unwrap,
            )

        if all(isinstance(x, SymNodeVariable) for x in args):
            return SymNodeVariable.create(tx, proxy, None)

        # Work around for vision_maskrcnn due to precision difference
        # specialize the dividend when float divide by tensor
        if fn is operator.truediv and isinstance(
            args[0], variables.UnspecializedPythonVariable
        ):
            args = list(args)
            args[0] = args[0].as_python_constant()
        return wrap_fx_proxy(tx, proxy)

    except NotImplementedError:
        unimplemented(
            gb_type="unimplemented builtin op on tensor arguments",
            context=f"partial tensor op: {self} {args} {kwargs}",
            explanation=f"Dynamo does not know how to trace builtin operator {self.fn} with tensor arguments",
            hints=[*graph_break_hints.SUPPORTABLE],
        )
# Handlers built by ``_make_handler``, keyed by
# ``(fn, *argument tracker types)`` plus a trailing ``True`` when the call
# has keyword arguments.
call_function_handler_cache: dict[
    tuple[object, ...],
    Callable[
        [
            "InstructionTranslator",
            Sequence[VariableTracker],
            dict[str, VariableTracker],
        ],
        VariableTracker,
    ],
] = {}

def call_function(
    self,
    tx: "InstructionTranslator",
    args: Sequence[VariableTracker],
    kwargs: dict[str, VariableTracker],
) -> VariableTracker:
    """Dispatch a call of the wrapped builtin through a cached handler.

    Keyword argument values are realized first. The handler is looked up in
    ``call_function_handler_cache`` and built with ``_make_handler`` on a
    miss, then invoked with ``(tx, args, kwargs)``.
    """
    has_kwargs = bool(kwargs)
    if has_kwargs:
        kwargs = {name: value.realize() for name, value in kwargs.items()}

    arg_types = [type(arg) for arg in args]
    key: tuple[object, ...] = (
        (self.fn, *arg_types, True) if has_kwargs else (self.fn, *arg_types)
    )

    handler = self.call_function_handler_cache.get(key)
    if not handler:
        handler = self._make_handler(self.fn, arg_types, has_kwargs)  # type: ignore[assignment]
        self.call_function_handler_cache[key] = handler
    assert handler is not None
    return handler(tx, args, kwargs)  # type: ignore[return-value]
def call_method(
    self,
    tx: "InstructionTranslator",
    name: str,
    args: list[VariableTracker],
    kwargs: dict[str, VariableTracker],
) -> VariableTracker:
    """Trace ``<builtin>.<name>(*args, **kwargs)`` for the wrapped builtin type.

    Covers ``object.__setattr__``, ``__new__`` for ``object``/``dict``/
    ``tuple``/``list``, selected ``float``/``complex`` class methods,
    ``object.__init__``, ``dict.fromkeys``, and unbound-method calls on
    ``dict``, ``set``, ``frozenset``, ``str`` and ``float``. Everything else
    is delegated to the base class.
    """
    if self.fn is object and name == "__setattr__":
        assert len(args) == 3
        assert len(kwargs) == 0
        target, attr_name_var, new_value = args
        target = target.realize()
        if (
            isinstance(target, UserDefinedObjectVariable)
            and tx.output.side_effects.is_attribute_mutation(target)
            and attr_name_var.is_python_constant()
        ):
            return target.method_setattr_standard(tx, attr_name_var, new_value)

    if name == "__new__":
        # Supported __new__ methods
        if self.fn is object and len(args) == 1:
            assert len(kwargs) == 0
            return tx.output.side_effects.track_new_user_defined_object(
                self, args[0], args[1:]
            )

        if self.fn is dict and len(args) == 1 and not kwargs:
            empty_dict_vt = ConstDictVariable(
                {}, dict, mutation_type=ValueMutationNew()
            )
            if isinstance(args[0], BuiltinVariable) and args[0].fn is dict:
                return empty_dict_vt
            # We don't have to set the underlying dict_vt in
            # UserDefinedDictVariable because it will be set to empty
            # ConstDictVariableTracker in the constructor.
            return tx.output.side_effects.track_new_user_defined_object(
                self,
                args[0],
                args[1:],
            )

        if (
            self.fn is tuple
            and len(args) == 2
            and args[1].has_force_unpack_var_sequence(tx)
            and not kwargs
        ):
            if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple:
                elements = args[1].force_unpack_var_sequence(tx)
                return variables.TupleVariable(
                    elements, mutation_type=ValueMutationNew()
                )
            return tx.output.side_effects.track_new_user_defined_object(
                self,
                args[0],
                args[1:],
            )

        if self.fn is list:
            empty_list_vt = ListVariable([], mutation_type=ValueMutationNew())
            if isinstance(args[0], BuiltinVariable) and args[0].fn is list:
                return empty_list_vt
            return tx.output.side_effects.track_new_user_defined_object(
                self,
                args[0],
                args[1:],
            )

    if (
        self.fn in (float, complex)
        and len(args) == 1
        and (
            (self.fn is float and name in ("fromhex", "hex"))
            or (name == "from_number" and sys.version_info >= (3, 14))
        )
    ):
        if args[0].is_python_constant():
            try:
                bound = getattr(self.fn, name)
                folded = bound(args[0].as_python_constant())
                return variables.ConstantVariable.create(folded)
            except (OverflowError, ValueError) as e:
                raise_observed_exception(
                    type(e),
                    tx,
                    args=list(map(ConstantVariable.create, e.args)),
                )

    if self.fn is object and name == "__init__":
        # object.__init__ is a no-op
        return variables.CONSTANT_VARIABLE_NONE

    if self.fn is dict and name == "fromkeys":
        return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)

    # Unbound-method calls such as ``dict.get(d, k)``: forward to the tracker
    # of the receiver (the first positional argument).
    if self.fn is dict:
        resolved_fn = getattr(self.fn, name)
        if resolved_fn in dict_methods:
            if isinstance(args[0], variables.UserDefinedDictVariable):
                return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
            elif isinstance(args[0], variables.ConstDictVariable):
                return args[0].call_method(tx, name, args[1:], kwargs)

    if self.fn is set:
        resolved_fn = getattr(self.fn, name)
        if resolved_fn in set_methods:
            if isinstance(args[0], variables.UserDefinedSetVariable):
                return args[0]._set_vt.call_method(tx, name, args[1:], kwargs)
            elif isinstance(args[0], variables.SetVariable):
                return args[0].call_method(tx, name, args[1:], kwargs)

    if self.fn is frozenset:
        resolved_fn = getattr(self.fn, name)
        if resolved_fn in frozenset_methods and isinstance(
            args[0], variables.FrozensetVariable
        ):
            return args[0].call_method(tx, name, args[1:], kwargs)

    if self.fn is str and len(args) >= 1:
        resolved_fn = getattr(self.fn, name)
        # Only delegate to ConstantVariable, not other types that happen to be constants
        if resolved_fn in str_methods and isinstance(args[0], ConstantVariable):
            return args[0].call_method(tx, name, args[1:], kwargs)

    # Only delegate to ConstantVariable, not other types that happen to be constants
    if (
        self.fn is float
        and len(args) >= 1
        and isinstance(args[0], ConstantVariable)
    ):
        return ConstantVariable.create(
            getattr(float, name)(args[0].as_python_constant())
        )

    return super().call_method(tx, name, args, kwargs)
def _call_int_float(
    self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker | None:
    """Trace ``int(x)`` / ``float(x)`` for tensor or ``SymNodeVariable`` input.

    Tensors are reduced with ``.item()`` first. The conversion is emitted as
    a ``sym_int`` call when the wrapped builtin is ``int`` and as
    ``sym_float`` otherwise. Returns ``None`` for any other argument kind.
    """
    # Handle cases like int(torch.seed())
    # Also handle sym_float to sym_int cases
    if not (arg.is_tensor() or isinstance(arg, SymNodeVariable)):
        return None

    scalar = arg.call_method(tx, "item", [], {}) if arg.is_tensor() else arg
    converter = sym_int if self.fn is int else sym_float
    from torch._dynamo.variables.builder import wrap_fx_proxy

    conversion_proxy = tx.output.create_proxy(
        "call_function",
        converter,
        (scalar.as_proxy(),),
        {},
    )
    return wrap_fx_proxy(tx=tx, proxy=conversion_proxy)

call_int = _call_int_float
call_float = _call_int_float
def call_bool(
    self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker | None:
    """Trace ``bool(x)`` for tensor and symbolic arguments.

    Tensors are reduced with ``.item()``; a symbolic boolean is returned
    unchanged, a constant is folded, and any other scalar becomes a symbolic
    ``!= 0`` test. ``SymNodeVariable`` inputs are treated the same way.
    Returns ``None`` for all other argument kinds.
    """
    if arg.is_tensor():
        scalar = arg.call_method(tx, "item", [], {})
        is_sym_bool = isinstance(scalar, SymNodeVariable) and isinstance(
            scalar.sym_num, torch.SymBool
        )
        if is_sym_bool:
            return scalar
        if isinstance(scalar, variables.ConstantVariable):
            return variables.ConstantVariable.create(bool(scalar.value))
        return SymNodeVariable.create(tx, scalar.as_proxy() != 0)

    # Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`.
    # https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697
    if not isinstance(arg, SymNodeVariable):
        # TODO handle more cases and merge this with `generic_jump`.
        return None

    # Note that we delay specializing on symbolic values to avoid
    # unnecessary guards. Specialization will happen later if, e.g., the
    # resulting boolean is used for branching.
    if isinstance(arg.sym_num, torch.SymBool):
        return arg

    # Emulate `nb_bool` of int/float objects
    # - https://github.com/python/cpython/blob/3.12/Objects/longobject.c#L4940-L4944
    # - https://github.com/python/cpython/blob/3.12/Objects/floatobject.c#L878-L882
    assert istype(arg.sym_num, (torch.SymInt, torch.SymFloat))
    return SymNodeVariable.create(tx, arg.as_proxy() != 0)
def call_repr(
    self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker | None:
    """Handle repr() on user defined objects, classes and selected containers.

    Returns ``None`` when the argument is not one of the handled kinds.
    """
    if isinstance(arg, variables.UserDefinedObjectVariable):
        instance_repr = arg.value.__repr__

        if type(arg.value).__repr__ is object.__repr__:
            # Default repr - build and trace it
            default_repr_vt = VariableTracker.build(tx, instance_repr)
            return default_repr_vt.call_function(tx, [], {})
        elif is_wrapper_or_member_descriptor(instance_repr):
            unimplemented(
                gb_type="Attempted to call repr() method implemented in C/C++",
                context="",
                explanation=f"{type(arg.value)} has a C/C++ based repr method. This is not supported.",
                hints=["Write the repr method in Python"],
            )
        else:
            # User-written __repr__: trace the underlying function with the
            # tracked object passed explicitly as ``self``.
            underlying_fn = instance_repr.__func__
            custom_repr_vt = VariableTracker.build(tx, underlying_fn)
            return custom_repr_vt.call_function(tx, [arg], {})

    if (
        isinstance(arg, variables.UserDefinedClassVariable)
        and type(arg.value).__repr__ is type.__repr__
    ):
        return variables.ConstantVariable.create(repr(arg.value))

    has_debug_repr = isinstance(
        arg,
        (
            RangeVariable,
            ConstDictVariable,
            DefaultDictVariable,
            OrderedSetClassVariable,
            DictViewVariable,
        ),
    )
    if has_debug_repr:
        return variables.ConstantVariable.create(arg.debug_repr())
    return None
def call_str(
    self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker | None:
    """Trace ``str(x)`` for user functions, user defined objects and exceptions.

    Returns ``None`` when the argument kind is not handled or when the
    object's string method cannot be evaluated or inlined.
    """
    # Handle `str` on a user defined function or object
    if isinstance(arg, (variables.UserFunctionVariable)):
        return variables.ConstantVariable.create(value=str(arg.fn))
    elif isinstance(arg, (variables.UserDefinedObjectVariable)):
        # Check if object has __str__ method
        if hasattr(arg.value, "__str__"):
            str_method = arg.value.__str__
        elif hasattr(arg.value, "__repr__"):
            # account for __repr__ functions when __str__ is absent
            str_method = arg.value.__repr__
        else:
            unimplemented(
                gb_type="failed to call str() on user defined object",
                context=str(arg),
                explanation="User defined object has no __str__ or __repr__ method",
                hints=[*graph_break_hints.USER_ERROR],
            )

        if type(arg.value).__str__ is object.__str__:
            # Rely on the object str method
            try:
                # pyrefly: ignore [unbound-name]
                rendered = str_method()
                return variables.ConstantVariable.create(value=rendered)
            except AttributeError:
                # Graph break
                return None
        # pyrefly: ignore [unbound-name]
        elif is_wrapper_or_member_descriptor(str_method):
            unimplemented(
                gb_type="Attempted to a str() method implemented in C/C++",
                context="",
                explanation=f"{type(arg.value)} has a C/C++ based str method. This is not supported.",
                hints=["Write the str method in Python"],
            )
        else:
            # Overrides for custom str method
            # Pass method as function to call tx.inline_user_function_return
            underlying_fn = str_method.__func__  # type: ignore[attr-defined]

            try:
                # Only supports certain function types
                str_fn_vt = VariableTracker.build(tx, underlying_fn)
            except AssertionError:
                # Won't be able to do inline the str method, return to avoid graph break
                log.warning("Failed to create UserFunctionVariable", exc_info=True)
                return None

            # Inline the user function
            return str_fn_vt.call_function(tx, [arg], {})
    elif isinstance(arg, (variables.ExceptionVariable,)):
        # NOTE(review): the join below requires every exception argument to
        # be a ``str`` constant - confirm whether non-string arguments can
        # reach this branch.
        if len(arg.args) == 0:
            text = f"{arg.exc_type}"
        else:
            text = ", ".join(a.as_python_constant() for a in arg.args)
        return variables.ConstantVariable.create(value=text)

    return None
def _call_min_max(
    self, tx: "InstructionTranslator", *args: VariableTracker
) -> VariableTracker | None:
    """Entry point for tracing ``min(...)`` / ``max(...)``.

    A single unpackable argument is treated as the sequence to reduce, two
    arguments use the binary handler directly, and three or more are reduced
    as a sequence. Returns ``None`` for any other call shape.
    """
    argc = len(args)
    if argc == 1 and args[0].has_force_unpack_var_sequence(tx):
        unpacked = args[0].force_unpack_var_sequence(tx)
        return self._call_min_max_seq(tx, unpacked)
    if argc == 2:
        first, second = args
        return self._call_min_max_binary(tx, first, second)
    if argc > 2:
        return self._call_min_max_seq(tx, args)
    return None
def _call_min_max_seq(
    self, tx: "InstructionTranslator", items: Sequence[VariableTracker]
) -> VariableTracker:
    """Reduce a non-empty sequence with the binary min/max handler.

    A single-element sequence is returned as is; longer sequences are folded
    left to right.
    """
    assert len(items) > 0
    if len(items) == 1:
        return items[0]

    def combine(acc: VariableTracker | None, nxt: VariableTracker) -> VariableTracker | None:
        return self._call_min_max_binary(tx, acc, nxt)

    return functools.reduce(combine, items)  # type: ignore[arg-type,return-value]
- def _call_min_max_binary(
- self,
- tx: "InstructionTranslator",
- a: VariableTracker | None,
- b: VariableTracker | None,
- ) -> VariableTracker | None:
- if a is None or b is None:
- # a or b could be none if we reduce and _call_min_max_binary failed
- # to return something
- return None
- if self.tensor_args(a, b):
- if not a.is_tensor():
- a, b = b, a
- assert a.is_tensor()
- # result of an item call is a scalar convert to a tensor
- if isinstance(a, FakeItemVariable):
- a = variables.TorchInGraphFunctionVariable(torch.tensor).call_function(
- tx, [a], {}
- )
- # Dynamic input does not get resolved, rather, gets stored as call_function
- if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
- from .builder import wrap_fx_proxy_cls
- return wrap_fx_proxy_cls(
- type(a),
- tx=tx,
- proxy=tx.output.create_proxy(
- "call_function",
- self.fn,
- *proxy_args_kwargs([a, b], {}),
- ),
- )
- # convert min/max to torch ops
- if b.is_python_constant():
- fn: VariableTracker
- if isinstance(a, variables.NumpyNdarrayVariable):
- import numpy as np
- fn = variables.NumpyVariable(np.clip)
- else:
- fn = variables.TorchInGraphFunctionVariable(torch.clamp)
- kwargs = {"min": b} if (self.fn is max) else {"max": b}
- result = fn.call_function(tx, [a], kwargs)
- else:
- if isinstance(a, variables.NumpyNdarrayVariable):
- import numpy as np
- np_fn = {max: np.maximum, min: np.minimum}[self.fn]
- fn = variables.NumpyVariable(np_fn)
- else:
- torch_fn = {max: torch.maximum, min: torch.minimum}[self.fn]
- fn = variables.TorchInGraphFunctionVariable(torch_fn)
- result = fn.call_function(tx, [a, b], {})
- # return unspec if both a, b are unspec or const
- if all(
- isinstance(
- i,
- (
- variables.UnspecializedPythonVariable,
- variables.ConstantVariable,
- ),
- )
- for i in [a, b]
- ):
- if any(isinstance(val, FakeItemVariable) for val in [a, b]):
- # type: ignore[arg-type]
- return variables.FakeItemVariable.from_tensor_variable(result)
- if b.is_python_constant():
- raw_b = b.as_python_constant()
- else:
- raw_b = b.raw_value # type: ignore[attr-defined]
- if self.fn is max:
- raw_res = max(a.raw_value, raw_b) # type: ignore[attr-defined]
- else:
- raw_res = min(a.raw_value, raw_b) # type: ignore[attr-defined]
- need_unwrap = any(
- x.need_unwrap
- for x in [a, b]
- if isinstance(x, variables.UnspecializedPythonVariable)
- )
- return variables.UnspecializedPythonVariable.from_tensor_variable(
- result, # type: ignore[arg-type]
- raw_res,
- need_unwrap,
- )
- # otherwise return tensor
- else:
- return result
- elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
- py_fn = torch.sym_max if self.fn is max else torch.sym_min
- proxy = tx.output.create_proxy(
- "call_function", py_fn, *proxy_args_kwargs([a, b], {})
- )
- return SymNodeVariable.create(tx, proxy, None)
- elif isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
- value = self.fn(
- a.as_python_constant(),
- b.as_python_constant(),
- )
- return ConstantVariable.create(value)
- return None
# min() and max() are served by one shared handler, which tells the two
# builtins apart by inspecting ``self.fn``.
call_min = _call_min_max
call_max = _call_min_max
def call_abs(
    self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
    """Trace ``abs(arg)`` by resolving ``arg.__abs__`` and calling it.

    Args:
        tx: The active instruction translator.
        arg: The variable passed to ``abs``.

    Returns:
        Whatever calling the resolved ``__abs__`` with no arguments yields.
    """
    from .builder import SourcelessBuilder

    # First look the dunder up through a traced ``getattr`` call ...
    getattr_vt = SourcelessBuilder.create(tx, getattr)
    dunder_name = ConstantVariable.create("__abs__")
    bound_abs = getattr_vt.call_function(tx, [arg, dunder_name], {})
    # ... then invoke it without arguments.
    return bound_abs.call_function(tx, [], {})
def call_pos(
    self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
    """Trace unary plus (``+arg``) by resolving ``arg.__pos__`` and calling it.

    Args:
        tx: The active instruction translator.
        arg: The operand of the unary plus.

    Returns:
        Whatever calling the resolved ``__pos__`` with no arguments yields.
    """
    from .builder import SourcelessBuilder

    # First look the dunder up through a traced ``getattr`` call ...
    getattr_vt = SourcelessBuilder.create(tx, getattr)
    dunder_name = ConstantVariable.create("__pos__")
    bound_pos = getattr_vt.call_function(tx, [arg, dunder_name], {})
    # ... then invoke it without arguments.
    return bound_pos.call_function(tx, [], {})
def call_index(
    self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
    """Trace ``operator.index(arg)`` and return the result as a constant.

    Tensor arguments are rejected through ``unimplemented``. Any other
    argument is passed through ``guard_if_dyn`` and the value that comes back
    is handed to ``operator.index``.

    Args:
        tx: The active instruction translator (not used directly here).
        arg: The variable passed to ``index``.

    Returns:
        A ``ConstantVariable`` wrapping the computed index.
    """
    if arg.is_tensor():
        unimplemented(
            gb_type="unsupported index(Tensor)",
            context="",
            explanation="Dynamo does not support tracing builtin index() on a Tensor",
            hints=[],
        )
    concrete_value = guard_if_dyn(arg)
    return variables.ConstantVariable.create(operator.index(concrete_value))
def call_round(
    self,
    tx: "InstructionTranslator",
    arg: VariableTracker,
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Trace ``round(arg, ...)`` by resolving ``arg.__round__`` and calling it.

    Args:
        tx: The active instruction translator.
        arg: The value being rounded.
        *args: Extra positional arguments, forwarded to ``__round__``.
        **kwargs: Extra keyword arguments, forwarded to ``__round__``.

    Returns:
        Whatever calling the resolved ``__round__`` yields.
    """
    from .builder import SourcelessBuilder

    # Look the dunder up through a traced ``getattr`` call, then invoke it
    # with the remaining arguments.
    getattr_vt = SourcelessBuilder.create(tx, getattr)
    dunder_name = ConstantVariable.create("__round__")
    bound_round = getattr_vt.call_function(tx, [arg, dunder_name], {})
    return bound_round.call_function(tx, args, kwargs)
def call_range(
    self, tx: "InstructionTranslator", *args: VariableTracker
) -> VariableTracker | None:
    """Trace ``range(*args)``.

    Args:
        tx: The active instruction translator (not used directly here).
        *args: The arguments given to ``range``.

    Returns:
        A ``RangeVariable`` when the arguments pass
        ``check_unspec_or_constant_args`` or when at least one of them is a
        ``SymNodeVariable``; in the latter case every argument is first passed
        through ``guard_if_dyn`` and re-wrapped as a constant. ``None``
        otherwise.
    """
    if check_unspec_or_constant_args(args, {}):
        return variables.RangeVariable(args)

    if not self._dynamic_args(*args):
        # None no-ops this handler and lets the driving function proceed
        return None

    constant_args = tuple(
        [variables.ConstantVariable.create(guard_if_dyn(arg)) for arg in args]
    )
    return variables.RangeVariable(constant_args)
def _dynamic_args(self, *args: VariableTracker, **kwargs: VariableTracker) -> bool:
    """Return True if any positional or keyword argument is a ``SymNodeVariable``."""
    # Positional arguments are inspected before keyword values.
    for candidate in (*args, *kwargs.values()):
        if isinstance(candidate, SymNodeVariable):
            return True
    return False
def call_slice(
    self, tx: "InstructionTranslator", *args: VariableTracker
) -> VariableTracker:
    """Trace ``slice(*args)`` by wrapping the arguments in a ``SliceVariable``."""
    slice_parts = args
    return variables.SliceVariable(slice_parts, tx)
def _dyn_proxy(
    self, tx: "InstructionTranslator", *args: Any, **kwargs: Any
) -> VariableTracker:
    """Record a ``call_function`` node for ``self.fn`` and wrap its proxy.

    Args:
        tx: The active instruction translator; the node is created on
            ``tx.output``.
        *args: Positional arguments converted with ``proxy_args_kwargs``.
        **kwargs: Keyword arguments converted with ``proxy_args_kwargs``.

    Returns:
        The result of ``wrap_fx_proxy`` for the newly created proxy.
    """
    from .builder import wrap_fx_proxy

    fx_proxy = tx.output.create_proxy(
        "call_function", self.fn, *proxy_args_kwargs(args, kwargs)
    )
    return wrap_fx_proxy(tx, fx_proxy)
- # NOTE must handle IteratorVariable separately!
def _call_iter_tuple_list(
    self,
    tx: "InstructionTranslator",
    obj: VariableTracker | None = None,
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker | None:
    """Build a new list/tuple variable from ``obj`` (non-iterator case).

    The concrete container class is chosen from ``self.fn`` via
    ``BaseListVariable.cls_for``.

    Args:
        tx: The active instruction translator.
        obj: The single argument of the constructor call, or ``None`` when the
            constructor was called without arguments. Must not be an
            ``IteratorVariable`` (asserted).
        *args: Extra positional arguments; only inspected for symbolic values.
        **kwargs: Extra keyword arguments; only inspected for symbolic values.

    Returns:
        A new container variable, the result of ``_dyn_proxy`` when any extra
        argument is a ``SymNodeVariable``, or ``None`` when ``obj`` does not
        report ``has_unpack_var_sequence``.
    """
    assert not isinstance(obj, variables.IteratorVariable)

    # Symbolic extra arguments: emit the call into the graph instead.
    if self._dynamic_args(*args, **kwargs):
        return self._dyn_proxy(tx, *args, **kwargs)

    cls = variables.BaseListVariable.cls_for(self.fn)
    if obj is None:
        # No argument: an empty, freshly created container.
        return cls(
            [],
            mutation_type=ValueMutationNew(),
        )
    elif obj.has_unpack_var_sequence(tx):
        # Guards are only installed for sourced, non-constant inputs.
        if obj.source and not is_constant_source(obj.source):
            if isinstance(obj, TupleIteratorVariable):
                install_guard(
                    obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN)
                )
            else:
                # Plain dict-like inputs (exact set/frozenset variables are
                # excluded) are registered in ``guard_on_key_order``.
                if (
                    getattr(obj, "source", False)
                    and isinstance(obj, ConstDictVariable)
                    and not istype(obj, (SetVariable, FrozensetVariable))
                ):
                    tx.output.guard_on_key_order.add(obj.source)

                if isinstance(obj, variables.MappingProxyVariable):
                    # This could be an overguarding, but its rare to iterate
                    # through a mapping proxy and not use the keys.
                    install_guard(
                        obj.source.make_guard(GuardBuilder.MAPPING_KEYS_CHECK)
                    )
                elif not isinstance(obj, variables.UnspecializedNNModuleVariable):
                    # Prevent calling __len__ method for guards, the tracing
                    # of __iter__ will insert the right guards later.
                    install_guard(
                        obj.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)
                    )

        return cls(
            list(obj.unpack_var_sequence(tx)),
            mutation_type=ValueMutationNew(),
        )
    # Not unpackable here: leave the call to other handling.
    return None
def _call_iter_tuple_generator(
    self,
    tx: "InstructionTranslator",
    obj: VariableTracker,
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Build a new list/tuple variable by fully draining ``obj``.

    The container class is selected from ``self.fn``; ``args`` and ``kwargs``
    are accepted but not used.
    """
    container_cls = variables.BaseListVariable.cls_for(self.fn)
    drained_items = list(obj.force_unpack_var_sequence(tx))  # exhaust generator
    return container_cls(drained_items, mutation_type=ValueMutationNew())
def _call_tuple_list(
    self,
    tx: "InstructionTranslator",
    obj: VariableTracker | None = None,
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker | None:
    """Shared handler behind the ``list(...)`` and ``tuple(...)`` builtins.

    Dispatch, in order:
      * ``IteratorVariable``: force-unpack it into a new container.
      * ``LocalGeneratorObjectVariable``, or a ``UserDefinedObjectVariable``
        that reports ``has_force_unpack_var_sequence``: delegate to
        ``_call_iter_tuple_generator``.
      * anything else (including ``None``): delegate to
        ``_call_iter_tuple_list``.
    """
    if isinstance(obj, variables.IteratorVariable):
        container_cls = variables.BaseListVariable.cls_for(self.fn)
        drained_items = list(obj.force_unpack_var_sequence(tx))
        return container_cls(drained_items, mutation_type=ValueMutationNew())

    generator_like = isinstance(obj, variables.LocalGeneratorObjectVariable) or (
        isinstance(obj, UserDefinedObjectVariable)
        and obj.has_force_unpack_var_sequence(tx)
    )
    if generator_like:
        return self._call_iter_tuple_generator(tx, obj, *args, **kwargs)
    return self._call_iter_tuple_list(tx, obj, *args, **kwargs)
def call_iter(
    self,
    tx: "InstructionTranslator",
    obj: VariableTracker,
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Trace the builtin ``iter(obj[, sentinel])``.

    Args:
        tx: The active instruction translator.
        obj: The object to iterate over.
        *args: Optional extra positional arguments (the sentinel form).
        **kwargs: Accepted but not used by this handler.

    Returns:
        The result of ``obj.__iter__`` for the known variable kinds below,
        otherwise the result of tracing the ``iter_`` polyfill, wrapped in an
        ``ObjectIteratorVariable`` when extra positional arguments were given.
    """
    # avoid the overhead of tracing the polyfill if we already know the
    # class implemented __iter__
    direct_iter_kinds = (
        variables.ListVariable,
        variables.RangeVariable,
        variables.IteratorVariable,
        variables.ConstDictVariable,
        variables.NNModuleVariable,
        variables.TensorVariable,
        variables.TupleVariable,
        DictViewVariable,
    )
    if isinstance(obj, direct_iter_kinds):
        return obj.call_method(tx, "__iter__", [], {})

    # If the object doesn't implement a __iter__ method, it will be an error
    # in eager mode when calling iter on it anyway.
    # If the object implements a __iter__ method, inlining effectively
    # forwards the call to another iter call (e.g. when __iter__ just returns
    # iter(self.list)) or return a user-defined iterator.
    # If the object implements a __getitem__ method, iter(...) will call
    # obj.__getitem__() with an integer argument starting at 0, until
    # __getitem__ raises IndexError
    iter_polyfill = variables.UserFunctionVariable(
        polyfills.builtins.iter_  # type: ignore[arg-type]
    )
    traced_result = iter_polyfill.call_function(tx, [obj, *args], {})
    if not args:
        return traced_result

    # iter(obj, sentinel) returns an object that implements __iter__ and
    # __next__ methods (UserDefinedObjectVariable). Wrap the return value in
    # an IteratorVariable subclass that forwards the next_variable call to
    # the object.
    return variables.ObjectIteratorVariable(traced_result)
# tuple() and list() are served by one shared handler; the concrete container
# class is picked from ``self.fn`` inside it.
call_tuple = _call_tuple_list
call_list = _call_tuple_list
def call_callable(
    self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker | None:
    """Constant-fold ``callable(arg)`` when the variable kind decides it.

    Returns:
        A constant ``True`` for class, user-function, ``functools.partial``
        and ``nn.Module`` variables; the result of ``callable(arg.value)``
        for other ``UserDefinedVariable`` instances; a constant ``False``
        for constant, symbolic, tensor, list, tuple and list-iterator
        variables; ``None`` for anything else.
    """
    from .functions import BaseUserFunctionVariable, FunctoolsPartialVariable
    from .nn_module import NNModuleVariable

    always_callable_kinds = (
        variables.UserDefinedClassVariable,
        BaseUserFunctionVariable,
        FunctoolsPartialVariable,
        NNModuleVariable,
    )
    if isinstance(arg, always_callable_kinds):
        verdict = True
    elif isinstance(arg, UserDefinedVariable):
        # Ask the real Python object.
        verdict = callable(arg.value)
    elif isinstance(
        arg,
        (
            ConstantVariable,
            SymNodeVariable,
            TensorVariable,
            ListVariable,
            TupleVariable,
            ListIteratorVariable,
        ),
    ):
        verdict = False
    else:
        return None
    return variables.ConstantVariable.create(verdict)
def call_cast(
    self, _: Any, *args: VariableTracker, **kwargs: VariableTracker
) -> VariableTracker | None:
    """Trace ``cast(typ, val)`` by returning ``val`` unchanged.

    Args:
        _: Ignored (the translator slot of the handler signature).
        *args: Expected to be exactly ``(typ, val)``.
        **kwargs: Only used for the error context.

    Returns:
        The second positional argument when exactly two were given. Any other
        arity is reported through ``unimplemented``.
    """
    if len(args) != 2:
        unimplemented(
            gb_type="bad args to builtin cast()",
            context=f"got args {args} {kwargs}",
            explanation="Dynamo expects exactly 2 args to builtin cast().",
            hints=["Ensure your call to cast() has exactly 2 arguments."],
        )
        # Only reached if unimplemented() returns instead of raising.
        return None
    return args[1]
def call_dir(
    self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker | None:
    """Constant-fold ``dir(arg)`` for user-defined classes and builtins.

    Returns:
        A variable built from ``dir(arg.value)`` for a
        ``UserDefinedClassVariable``, from ``dir(arg.fn)`` for a
        ``BuiltinVariable``, and ``None`` for every other kind.
    """
    if isinstance(arg, variables.UserDefinedClassVariable):
        inspected = arg.value
    elif isinstance(arg, BuiltinVariable):
        inspected = arg.fn
    else:
        return None
    return VariableTracker.build(tx, dir(inspected))
def call_dict(
    self,
    tx: "InstructionTranslator",
    /,
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Trace ``dict(...)`` by delegating to ``call_custom_dict`` with ``dict``."""
    construct = BuiltinVariable.call_custom_dict
    return construct(tx, dict, *args, **kwargs)
@staticmethod
def call_custom_dict(
    tx: "InstructionTranslator",
    user_cls: type,
    /,
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Trace construction of ``user_cls(*args, **kwargs)`` for dict-like classes.

    The work is done by inlining ``polyfills.construct_dict`` with
    ``user_cls`` prepended to the positional arguments.

    Args:
        tx: The active instruction translator.
        user_cls: The dict class being instantiated.
        *args: Positional constructor arguments.
        **kwargs: Keyword constructor arguments.

    Returns:
        The variable returned by the inlined polyfill.
    """
    positional = list(args)
    if len(positional) == 1:
        sole_arg = positional[0]
        if (
            isinstance(sole_arg, variables.GetAttrVariable)
            and isinstance(sole_arg.obj, variables.UserDefinedClassVariable)
            and not tx.output.side_effects.has_pending_mutation(sole_arg.obj)
        ):
            # Forward the GetAttrVariable(foo, "__dict__") to a realized vt of
            # VT(foo.__dict__). This simplifies the construction of the new
            # dict.
            positional[0] = sole_arg.get_forwarded_dict(tx)

    polyfill_vt = VariableTracker.build(tx, polyfills.construct_dict)
    call_args = [VariableTracker.build(tx, user_cls), *positional]
    return tx.inline_user_function_return(polyfill_vt, call_args, kwargs)
@staticmethod
def call_custom_dict_fromkeys(
    tx: "InstructionTranslator",
    user_cls: type,
    /,
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Trace ``user_cls.fromkeys(iterable[, value])``.

    Only ``dict``, ``OrderedDict`` and ``defaultdict`` are accepted as
    ``user_cls``. ``value`` defaults to ``CONSTANT_VARIABLE_NONE`` and may be
    passed by keyword only for ``OrderedDict``.

    Args:
        tx: The active instruction translator.
        user_cls: The dict class ``fromkeys`` was called on.
        *args: ``(iterable,)`` or ``(iterable, value)``.
        **kwargs: Optionally ``{"value": ...}`` (``OrderedDict`` only).

    Returns:
        A ``DefaultDictVariable`` for ``defaultdict``, otherwise a
        ``ConstDictVariable``, mapping every key to ``value``. Argument
        mismatches go through ``raise_args_mismatch``; unsupported classes,
        non-unpackable iterables and unhashable keys go through
        ``unimplemented``.
    """
    supported_classes = {dict, OrderedDict, defaultdict}
    if user_cls not in supported_classes:
        unimplemented(
            gb_type="Unsupported dict type for fromkeys()",
            context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}",
            explanation=f"Failed to call {user_cls.__name__}.fromkeys() because "
            f"{user_cls.__name__} is not any type of dict, OrderedDict, or defaultdict",
            hints=[
                f"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict.",
            ],
        )

    if kwargs:
        # Only `OrderedDict.fromkeys` accepts `value` passed by keyword
        value_by_keyword_ok = (
            user_cls is OrderedDict
            and len(args) == 1
            and len(kwargs) == 1
            and "value" in kwargs
        )
        if not value_by_keyword_ok:
            raise_args_mismatch(
                tx,
                f"{user_cls.__name__}.fromkeys",
                "1 args and 1 kwargs (`value`)",
                f"{len(args)} args and {len(kwargs)} kwargs",
            )
        # Normalise to the positional form from here on.
        args = (*args, kwargs.pop("value"))

    if not args:
        raise_args_mismatch(
            tx,
            f"{user_cls.__name__}.fromkeys",
            "at least 1 args",
            f"{len(args)} args",
        )
    if len(args) == 1:
        # Fill in the default value.
        args = (*args, CONSTANT_VARIABLE_NONE)
    if len(args) != 2:
        raise_args_mismatch(
            tx,
            f"{user_cls.__name__}.fromkeys",
            "2 args",
            f"{len(args)} args",
        )

    arg, value = args
    if user_cls is defaultdict:
        dict_vt_cls = DefaultDictVariable
    else:
        dict_vt_cls = ConstDictVariable

    def _build_from(keys: Any) -> VariableTracker:
        # Every key maps to the same ``value`` variable.
        return dict_vt_cls(
            dict.fromkeys(keys, value),
            user_cls,
            mutation_type=ValueMutationNew(),
        )

    if isinstance(arg, dict):
        return _build_from([ConstantVariable.create(k) for k in arg])
    if arg.has_force_unpack_var_sequence(tx):
        keys = arg.force_unpack_var_sequence(tx)
        if all(is_hashable(v) for v in keys):
            return _build_from(keys)

    unimplemented(
        gb_type="failed to call dict.fromkeys()",
        context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}",
        explanation=f"Failed to call {user_cls.__name__}.fromkeys() because "
        "arguments could not be automatically converted to a list, "
        "or some dict key is not hashable.",
        hints=[
            "Manually convert the argument to a list.",
            "Ensure all keys are hashable.",
        ],
    )
def call_set(
    self,
    tx: "InstructionTranslator",
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Trace the builtin ``set(...)`` constructor.

    Supported forms:
      * ``set()``: a new empty ``SetVariable``.
      * ``set(x)`` where ``x`` is exactly a ``SetVariable``: a clone of it.
      * ``set(x)`` where ``x`` reports ``has_force_unpack_var_sequence``:
        a new ``SetVariable`` over the unpacked items.
      * ``set(x)`` where ``x`` is a user-defined object wrapping a
        ``KeysView`` whose ``__iter__`` resolves to a ``UserMethodVariable``:
        that method is inlined and the set is built from its result.

    Args:
        tx: The active instruction translator.
        *args: At most one positional argument (the iterable).
        **kwargs: Must be empty; ``set`` takes no keyword arguments.

    Returns:
        The newly created set variable.

    Raises:
        An observed ``TypeError`` (via ``raise_observed_exception``) for
        keyword arguments, more than one positional argument, or an argument
        none of the forms above can handle.
    """
    from .builder import SourcelessBuilder

    # Can we merge this implementation and call_dict's one?
    # Keyword arguments are a user error, not an internal invariant, so they
    # are reported as a TypeError (matching CPython) instead of an assert,
    # which would surface as an internal error or vanish under ``-O``.
    if kwargs:
        raise_observed_exception(
            TypeError,
            tx,
            args=[ConstantVariable.create("set() takes no keyword arguments")],
        )
    if not args:
        return SetVariable([], mutation_type=ValueMutationNew())
    if len(args) != 1:
        raise_observed_exception(
            TypeError,
            tx,
            args=[
                ConstantVariable.create(
                    f"set() takes 1 positional argument but {len(args)} were given"
                )
            ],
        )
    arg = args[0]
    if istype(arg, variables.SetVariable):
        return arg.clone(mutation_type=ValueMutationNew())
    elif arg.has_force_unpack_var_sequence(tx):
        items = arg.force_unpack_var_sequence(tx)
        return SetVariable(items, mutation_type=ValueMutationNew())
    elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
        arg.value, KeysView
    ):
        iter_fn = arg.var_getattr(tx, "__iter__")
        if isinstance(iter_fn, variables.UserMethodVariable):
            out = tx.inline_user_function_return(iter_fn, args, kwargs)
            if isinstance(out, SetVariable):
                return out
            # Otherwise build the set from whatever ``__iter__`` produced.
            return SourcelessBuilder.create(tx, set).call_set(tx, out)
    raise_observed_exception(
        TypeError,
        tx,
        args=[ConstantVariable.create("failed to construct builtin set()")],
    )
def call_frozenset(
    self,
    tx: "InstructionTranslator",
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Trace the builtin ``frozenset(...)`` constructor.

    Supported forms:
      * ``frozenset()``: a new empty ``FrozensetVariable``.
      * ``frozenset(x)`` where ``x`` is exactly a ``FrozensetVariable``:
        a new variable over the ``vt`` of each of its ``set_items``.
      * ``frozenset(x)`` where ``x`` reports
        ``has_force_unpack_var_sequence``: a new variable over the unpacked
        items.

    Args:
        tx: The active instruction translator.
        *args: At most one positional argument (the iterable).
        **kwargs: Must be empty; ``frozenset`` takes no keyword arguments.

    Returns:
        The newly created frozenset variable.

    Raises:
        An observed ``TypeError`` (via ``raise_observed_exception``) for
        keyword arguments, more than one positional argument, or an argument
        none of the forms above can handle.
    """
    # Keyword arguments are a user error, not an internal invariant, so they
    # are reported as a TypeError (matching CPython) instead of an assert,
    # which would surface as an internal error or vanish under ``-O``.
    if kwargs:
        raise_observed_exception(
            TypeError,
            tx,
            args=[
                ConstantVariable.create("frozenset() takes no keyword arguments")
            ],
        )
    if not args:
        return FrozensetVariable([])
    if len(args) != 1:
        raise_observed_exception(
            TypeError,
            tx,
            args=[
                ConstantVariable.create(
                    f"frozenset() takes 1 positional argument but {len(args)} were given"
                )
            ],
        )
    arg = args[0]
    if istype(arg, variables.FrozensetVariable):
        return FrozensetVariable([x.vt for x in arg.set_items])
    elif arg.has_force_unpack_var_sequence(tx):
        items = arg.force_unpack_var_sequence(tx)
        return FrozensetVariable(items)
    raise_observed_exception(
        TypeError,
        tx,
        args=[ConstantVariable.create("failed to construct builtin frozenset()")],
    )
def call_zip(
    self,
    tx: "InstructionTranslator",
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Trace ``zip(*iterables, strict=False)``.

    Args:
        tx: The active instruction translator.
        *args: The iterables to zip; each one is passed through a traced
            ``iter`` call.
        **kwargs: May only contain ``strict``.

    Returns:
        A new ``ZipVariable`` over the traced iterators. ``strict`` is read
        with ``as_python_constant``.
    """
    from .builder import SourcelessBuilder

    only_strict_given = len(kwargs) == 1 and "strict" in kwargs
    if kwargs and not only_strict_given:
        raise_args_mismatch(
            tx,
            "zip",
            "1 kwargs (`strict`)",
            f"{len(kwargs)} kwargs",
        )
    strict = kwargs.pop("strict", ConstantVariable.create(False))

    iter_args = []
    for arg in args:
        iter_builtin = SourcelessBuilder.create(tx, iter)
        iter_args.append(iter_builtin.call_function(tx, [arg], {}))

    return variables.ZipVariable(
        iter_args,
        strict=strict.as_python_constant(),
        mutation_type=ValueMutationNew(),
    )
def call_len(
    self,
    tx: "InstructionTranslator",
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Trace ``len(x)`` by calling ``__len__`` on the first argument.

    Remaining positional arguments and all keyword arguments are forwarded
    unchanged. An ``AttributeError`` escaping the call is re-raised as an
    observed exception of the same type carrying the same args.
    """
    try:
        target = args[0]
        trailing = list(args[1:])
        return target.call_method(tx, "__len__", trailing, kwargs)
    except AttributeError as err:
        raise_observed_exception(type(err), tx, args=list(err.args))
def call_getitem(
    self,
    tx: "InstructionTranslator",
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Trace ``operator.getitem``-style calls via ``__getitem__`` on the first argument.

    Remaining positional arguments and all keyword arguments are forwarded
    unchanged.
    """
    container = args[0]
    index_args = list(args[1:])
    return container.call_method(tx, "__getitem__", index_args, kwargs)
def call_isinstance(
    self,
    tx: "InstructionTranslator",
    arg: VariableTracker,
    isinstance_type_var: VariableTracker,
) -> VariableTracker:
    """Constant-fold ``isinstance(arg, isinstance_type_var)`` at trace time.

    Args:
        tx: The active instruction translator.
        arg: The object being tested.
        isinstance_type_var: The second argument of ``isinstance``; it is
            converted with ``as_python_constant``.

    Returns:
        A ``ConstantVariable`` holding the boolean answer.

    The answer is derived, in order, from:
      1. a dtype-aware check for tensor variables with a known dtype;
      2. a user-defined ``__instancecheck__`` on the type's metaclass (for
         ``UserDefinedObjectVariable`` arguments);
      3. a direct ``isinstance`` on the Python type for
         ``UserDefinedExceptionClassVariable`` arguments;
      4. ``issubclass(arg.python_type(), <types>)``, falling back to a
         membership test if that raises ``TypeError``.

    Cases that cannot be handled are reported through ``unimplemented`` or,
    for an invalid second argument, as an observed ``TypeError``.
    """
    try:
        arg_type = arg.python_type()
    except NotImplementedError:
        unimplemented(
            gb_type="builtin isinstance() cannot determine type of argument",
            context=f"isinstance({arg}, {isinstance_type_var})",
            explanation=f"Dynamo doesn't have a rule to determine the type of argument {arg}",
            hints=[*graph_break_hints.DYNAMO_BUG],
        )
    isinstance_type = isinstance_type_var.as_python_constant()
    if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:

        def _tensor_isinstance(
            tensor_var: VariableTracker, tensor_type: Any
        ) -> bool:
            # NOTE(review): ``tensor_var`` is not read; the helpers below
            # close over ``arg`` from the enclosing scope instead.
            def check_type(ty: Any) -> bool:
                if ty not in tensortype_to_dtype:
                    example_val = arg.as_proxy().node.meta["example_value"]
                    if (
                        is_traceable_wrapper_subclass(example_val)
                        and ty is torch.nn.parameter.Parameter
                    ):
                        # N.B: we are calling isinstance directly on the example value.
                        # torch.nn.Parameter has a meta-class that overrides __instancecheck__,
                        # the isinstance check here allows us to invoke that logic.
                        return isinstance(example_val, ty)
                    else:
                        return issubclass(arg.python_type(), ty)

                # ``ty`` is a key of ``tensortype_to_dtype``: compare dtypes.
                dtypes = tensortype_to_dtype[ty]
                # pyrefly: ignore [missing-attribute]
                return arg.dtype in dtypes

            # Only an exact tuple is expanded element by element.
            if type(tensor_type) is tuple:
                return any(check_type(ty) for ty in tensor_type)
            else:
                return check_type(tensor_type)

        return variables.ConstantVariable.create(
            _tensor_isinstance(arg, isinstance_type)
        )
    # UserDefinedObject with C extensions can have torch.Tensor attributes,
    # so break graph.
    if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
        arg.value, types.MemberDescriptorType
    ):
        unimplemented(
            gb_type="isinstance() called on user defined object with C extensions",
            context=f"isinstance({arg}, {isinstance_type})",
            explanation="User-defined object with C extensions can have torch.Tensor "
            "attributes; intentionally graph breaking.",
            hints=[*graph_break_hints.SUPPORTABLE],
        )
    # handle __instancecheck__ defined in user class
    if (
        isinstance(arg, variables.UserDefinedObjectVariable)
        and "__instancecheck__" in isinstance_type.__class__.__dict__
    ):
        # The hook is run directly on the real object, not traced.
        return variables.ConstantVariable.create(
            isinstance_type.__class__.__instancecheck__(isinstance_type, arg.value)
        )

    if isinstance(arg, variables.UserDefinedExceptionClassVariable):
        # pyrefly: ignore [unbound-name]
        return ConstantVariable.create(isinstance(arg_type, isinstance_type))

    # Normalise the second argument into a tuple of candidate types.
    isinstance_type_tuple: tuple[type, ...]
    if isinstance(isinstance_type, type) or callable(
        # E.g. isinstance(obj, typing.Sequence)
        getattr(isinstance_type, "__instancecheck__", None)
    ):
        isinstance_type_tuple = (isinstance_type,)
    elif isinstance(isinstance_type, types.UnionType):
        isinstance_type_tuple = isinstance_type.__args__
    elif isinstance(isinstance_type, tuple) and all(
        isinstance(tp, type) or callable(getattr(tp, "__instancecheck__", None))
        for tp in isinstance_type
    ):
        isinstance_type_tuple = isinstance_type
    else:
        raise_observed_exception(
            TypeError,
            tx,
            args=[
                "isinstance() arg 2 must be a type, a tuple of types, or a union"
            ],
        )

    try:
        # NB: `isinstance()` does not call `__subclasscheck__` but use `__instancecheck__`.
        # But usually `isinstance(obj, type_info)` and `issubclass(type(obj), type_info)` gives
        # the same result.
        # WARNING: This might run arbitrary user code `__subclasscheck__` and we did not trace
        # through it. This is a limitation of the current implementation.
        # Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it
        # might not be a big issue and we trade off it for performance.
        # pyrefly: ignore [unbound-name]
        val = issubclass(arg_type, isinstance_type_tuple)
    except TypeError:
        # Fall back to a plain membership test on the candidate tuple.
        # pyrefly: ignore [unbound-name]
        val = arg_type in isinstance_type_tuple
    return variables.ConstantVariable.create(val)
def call_issubclass(
    self,
    tx: "InstructionTranslator",
    left_ty: VariableTracker,
    right_ty: VariableTracker,
) -> VariableTracker:
    """Constant-fold ``issubclass(left_ty, right_ty)``.

    Both arguments are converted with ``as_python_constant``; a
    ``NotImplementedError`` from either conversion is reported through
    ``unimplemented``.

    Returns:
        A ``ConstantVariable`` holding the boolean answer.
    """
    try:
        candidate_cls = left_ty.as_python_constant()
        class_info = right_ty.as_python_constant()
    except NotImplementedError:
        unimplemented(
            gb_type="issubclass() with non-constant arguments",
            context=f"issubclass({left_ty}, {right_ty})",
            explanation="issubclass() with non-constant arguments not supported.",
            hints=[
                "Make sure your arguments are types.",
                *graph_break_hints.USER_ERROR,
            ],
        )

    # WARNING: This might run arbitrary user code `__subclasscheck__`.
    # See the comment in call_isinstance.
    # pyrefly: ignore [unbound-name]
    is_subclass = issubclass(candidate_cls, class_info)
    return variables.ConstantVariable(is_subclass)
def call_super(
    self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
) -> VariableTracker:
    """Trace two-argument ``super(a, b)`` as a ``SuperVariable``; ``tx`` is unused."""
    super_vt = variables.SuperVariable(a, b)
    return super_vt
def call_next(
    self, tx: "InstructionTranslator", *args: VariableTracker
) -> VariableTracker:
    """Trace ``next(iterator[, default])``.

    Args:
        tx: The active instruction translator.
        *args: ``(iterator,)`` or ``(iterator, default)``.

    Returns:
        The value from ``iterator.next_variable(tx)``. When that raises
        ``ObservedUserStopIteration`` and exactly two arguments were given,
        the second argument is returned instead. When it raises
        ``Unsupported`` and the iterator is a ``BaseListVariable``, the
        exception is removed from the stats and the first item of the list
        is returned. All other exceptions propagate.
    """
    iterator = args[0]
    try:
        return iterator.next_variable(tx)
    except ObservedUserStopIteration:
        if len(args) != 2:
            raise
        return args[1]
    except Unsupported as exc:
        if not isinstance(iterator, variables.BaseListVariable):
            raise
        exc.remove_from_stats()
        return iterator.items[0]
def call_hasattr(
    self, tx: "InstructionTranslator", obj: VariableTracker, attr: VariableTracker
) -> VariableTracker | None:
    """Trace ``hasattr(obj, attr)`` for a constant attribute name.

    Returns:
        ``None`` when ``attr`` is not a Python constant. For a
        ``BuiltinVariable`` the answer is computed with ``hasattr`` on the
        wrapped builtin; otherwise the question is delegated to
        ``obj.call_obj_hasattr``.
    """
    if not attr.is_python_constant():
        return None

    name = attr.as_python_constant()
    if isinstance(obj, variables.BuiltinVariable):
        return variables.ConstantVariable(hasattr(obj.fn, name))
    return obj.call_obj_hasattr(tx, name)
def call_map(
    self,
    tx: "InstructionTranslator",
    fn: VariableTracker,
    *seqs: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Trace ``map(fn, *seqs)``.

    Args:
        tx: The active instruction translator.
        fn: The function to map.
        *seqs: The input iterables. Those reporting
            ``has_unpack_var_sequence`` are unpacked eagerly; the others are
            passed through as-is.
        **kwargs: Only ``strict`` is accepted, and only on Python 3.14+;
            any other use goes through ``raise_args_mismatch``.

    Returns:
        A new ``MapVariable``.
    """
    strict = ConstantVariable.create(False)
    if kwargs:
        if sys.version_info < (3, 14):
            raise_args_mismatch(
                tx,
                "map",
                "0 kwargs",
                f"{len(kwargs)} kwargs",
            )
        else:
            if len(kwargs) != 1 or "strict" not in kwargs:
                raise_args_mismatch(
                    tx,
                    "map",
                    "1 kwargs (`strict`)",
                    f"{len(kwargs)} kwargs",
                )
            strict = kwargs.pop("strict", ConstantVariable.create(False))

    seq_list = []
    for seq in seqs:
        if seq.has_unpack_var_sequence(tx):
            seq_list.append(seq.unpack_var_sequence(tx))
        else:
            seq_list.append(seq)

    return variables.MapVariable(
        fn,
        seq_list,  # type: ignore[arg-type]
        strict=strict.as_python_constant(),
        mutation_type=ValueMutationNew(),
    )
def call_filter(
    self, tx: "InstructionTranslator", fn: VariableTracker, seq: VariableTracker
) -> VariableTracker:
    """Trace ``filter(fn, seq)`` as a new ``FilterVariable``.

    ``seq`` is unpacked eagerly when it reports ``has_unpack_var_sequence``;
    otherwise it is passed through unchanged.
    """
    if seq.has_unpack_var_sequence(tx):
        filter_input = seq.unpack_var_sequence(tx)
    else:
        filter_input = seq
    return variables.FilterVariable(
        fn,
        filter_input,  # type: ignore[arg-type]
        mutation_type=ValueMutationNew(),
    )
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
    """Resolve attribute ``name`` on the wrapped builtin.

    For the ``object`` builtin, non-callable attributes are read directly and
    wrapped with ``VariableTracker.build``; a missing attribute becomes an
    observed ``AttributeError``. Everything else is represented lazily as a
    ``GetAttrVariable``.
    """
    base_source = self.source
    source = AttrSource(base_source, name) if base_source else base_source

    if self.fn is not object:
        return variables.GetAttrVariable(self, name, source=source)

    # for object, we can just directly read the attribute
    try:
        value = getattr(self.fn, name)
    except AttributeError:
        raise_observed_exception(AttributeError, tx)
    # pyrefly: ignore [unbound-name]
    if callable(value):
        return variables.GetAttrVariable(self, name, source=source)
    # pyrefly: ignore [unbound-name]
    return VariableTracker.build(tx, value, source)
def call_getattr(
    self,
    tx: "InstructionTranslator",
    obj: VariableTracker,
    name_var: VariableTracker,
    default: VariableTracker | None = None,
) -> VariableTracker | None:
    """Trace ``getattr(obj, name[, default])``.

    Args:
        tx: The active instruction translator.
        obj: The object whose attribute is read.
        name_var: The attribute name; must be a Python constant.
        default: Optional fallback returned when the attribute is judged
            missing (see the ``default`` handling below).

    Returns:
        The variable for the attribute, ``default``, or ``None`` for a
        ``TorchInGraphFunctionVariable`` attribute this handler does not
        recognise.

    Processing order:
      1. reject non-constant names;
      2. route tensor ``_grad`` to ``grad``;
      3. serve pending attribute mutations from the side-effects table;
      4. apply ``default`` based on ``call_hasattr``;
      5. special-case ``__bases__`` / ``__base__`` / ``__flags__`` on types;
      6. dispatch on the kind of ``obj``.
    """
    if not name_var.is_python_constant():
        unimplemented(
            gb_type="getattr() with non-constant name argument",
            context=f"getattr({obj}, {name_var}, {default})",
            explanation="getattr() with non-constant name argument is not supported",
            hints=["Ensure the name argument of getattr() is a string"],
        )

    name = name_var.as_python_constant()

    # See NOTE [Tensor "grad" and "_grad" attr]
    if obj.is_tensor() and name == "_grad":
        name = "grad"

    if tx.output.side_effects.is_attribute_mutation(obj):
        if isinstance(obj, variables.UnspecializedNNModuleVariable):
            # Listing parameters/buffers/modules of a module whose state has
            # a pending mutation is not traced.
            if (
                name
                in (
                    "named_parameters",
                    "parameters",
                    "named_buffers",
                    "buffers",
                    "named_modules",
                    "modules",
                )
                and obj.is_state_mutated
                and tx.output.side_effects.has_pending_mutation(obj)
            ):
                unimplemented(
                    gb_type="getattr() on nn.Module with pending mutation",
                    context=f"getattr({obj}, {name}, {default})",
                    explanation="Intentionally graph breaking on getattr() on a nn.Module "
                    "with a pending mutation",
                    hints=[],
                )

        # A pending store to this attribute wins over the original value.
        if tx.output.side_effects.has_pending_mutation_of_attr(obj, name):
            return tx.output.side_effects.load_attr(obj, name)

    if default is not None:
        hasattr_var = self.call_hasattr(tx, obj, name_var)
        if hasattr_var is not None:
            assert hasattr_var.is_constant_match(True, False)
            if not hasattr_var.as_python_constant():
                return default
        else:
            # call_hasattr could not decide: fall back to the default.
            return default

    source = obj.source and AttrSource(obj.source, name)
    if name in {"__bases__", "__base__", "__flags__"}:
        try:
            value = obj.as_python_constant()
            if isinstance(value, type):
                if name == "__bases__":
                    tuple_args = [
                        VariableTracker.build(
                            tx, b, source and GetItemSource(source, i)
                        )
                        for i, b in enumerate(value.__bases__)
                    ]
                    return variables.TupleVariable(tuple_args, source=source)
                if name == "__base__":
                    return VariableTracker.build(tx, value.__base__, source)
                if name == "__flags__":
                    return ConstantVariable.create(value.__flags__)
        except NotImplementedError:
            # ``obj`` is not a constant: continue with the generic dispatch.
            pass

    if isinstance(obj, variables.NNModuleVariable):
        return obj.var_getattr(tx, name)
    elif isinstance(
        obj,
        (
            variables.TensorVariable,
            variables.NamedTupleVariable,
            variables.ConstantVariable,
            variables.DefaultDictVariable,
            variables.DistributedVariable,
            variables.UserDefinedClassVariable,
            variables.UserDefinedObjectVariable,
        ),
    ):
        # Selected unittest.TestCase assertion helpers are not traceable.
        if (
            isinstance(obj, variables.UserDefinedObjectVariable)
            and issubclass(obj.value.__class__, unittest.TestCase)
            and config.enable_trace_unittest
            and name
            in (
                "assertRaisesRegex",
                "assertNotWarns",
                "assertWarnsRegex",
                "assertWarns",
            )
        ):
            unimplemented(
                gb_type="Failed to trace unittest method",
                context=f"function: unittest.TestCase.{name}",
                explanation=f"Dynamo does not know how to trace unittest method `{name}` ",
                hints=[
                    f"Avoid calling `TestCase.{name}`. "
                    "Please report an issue to PyTorch.",
                ],
            )
        if obj.is_tensor():
            # pyrefly: ignore[missing-attribute]
            fake_val = obj.as_proxy().node.meta["example_value"]
            # Sparse tensors are only allowed when exporting with
            # ``config.capture_sparse_compute`` enabled.
            if (
                isinstance(fake_val, torch.Tensor)
                and is_sparse_any(fake_val)
                and (not tx.export or not config.capture_sparse_compute)
            ):
                unimplemented(
                    gb_type="Attempted to wrap sparse Tensor",
                    context="",
                    explanation="torch.compile does not support sparse Tensors",
                    hints=[*graph_break_hints.SPARSE_TENSOR],
                )

        try:
            return obj.var_getattr(tx, name)
        except AsPythonConstantNotImplementedError:
            # dont fallback on as_python_constant error because this leads
            # to a failure later on, and leads to a wrong stacktrace
            raise
        except NotImplementedError:
            return variables.GetAttrVariable(obj, name, source=source)
    elif isinstance(obj, variables.TorchInGraphFunctionVariable):
        # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default.
        member = getattr(obj.value, name)
        if isinstance(
            member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
        ) and torch._dynamo.trace_rules.is_aten_op_or_tensor_method(member):
            return variables.TorchInGraphFunctionVariable(member, source=source)
        elif name in cmp_name_to_op_mapping:
            return variables.GetAttrVariable(obj, name, source=source)
        else:
            return None
    elif isinstance(obj, DummyModule):
        # TODO(mlazos) - Do we need this?
        if obj.is_torch or name not in obj.value.__dict__:
            member = getattr(obj.value, name)
        else:
            member = obj.value.__dict__[name]

        if config.replay_record_enabled:
            tx.exec_recorder.record_module_access(obj.value, name, member)  # type: ignore[arg-type, union-attr]
        return VariableTracker.build(tx, member, source)
    elif istype(obj, variables.UserFunctionVariable) and name in (
        "__name__",
        "__module__",
    ):
        return ConstantVariable.create(getattr(obj.fn, name))
    else:
        # Generic path: ask the variable itself, or defer the lookup.
        try:
            return obj.var_getattr(tx, name)
        except NotImplementedError:
            return variables.GetAttrVariable(obj, name, source=source)
def call_setattr(
    self,
    tx: "InstructionTranslator",
    obj: VariableTracker,
    name_var: VariableTracker,
    val: VariableTracker,
) -> VariableTracker | None:
    """Trace ``setattr(obj, name, val)``.

    Args:
        tx: The active instruction translator.
        obj: The object being mutated.
        name_var: The attribute name.
        val: The value being assigned.

    Returns:
        * the result of ``obj.__setattr__`` for the variable kinds listed in
          the first branch;
        * ``val`` after recording the store in the side-effects table, for
          objects tracked as attribute mutations (constant name only);
        * the variable produced by the ``set_`` call for ``tensor.data = ...``;
        * the existing attribute variable when an ``nn.Module`` attribute is
          assigned the very same fake tensor it already holds;
        * ``None`` otherwise.

    NOTE(review): several user-facing explanations below spell "Dynamo" as
    "Dyanmo"; the strings are left untouched here.
    """
    if isinstance(
        obj,
        (
            variables.DefaultDictVariable,
            variables.PlacementVariable,
            variables.NamedTupleVariable,
            variables.UserDefinedObjectVariable,
            variables.NestedUserFunctionVariable,
            variables.ExceptionVariable,
            variables.TracebackVariable,
        ),
    ):
        # These variable kinds implement ``__setattr__`` themselves.
        return obj.call_method(tx, "__setattr__", [name_var, val], {})
    elif (
        tx.output.side_effects.is_attribute_mutation(obj)
        and name_var.is_python_constant()
    ):
        name = name_var.as_python_constant()
        if obj.is_tensor():
            from .builder import wrap_fx_proxy

            # Some special handling for tensor attributes.
            if name == "requires_grad":
                # TODO(voz): Make it work properly
                unimplemented(
                    gb_type="setattr() on Tensor.requires_grad",
                    context=f"setattr({obj}, {name}, {val})",
                    explanation="setattr() on Tensor.requires_grad not supported. "
                    "Mutating requires_grad can introduce a new leaf from non-leaf or vice versa in "
                    "the middle of the graph, which AOTAutograd does not currently know how to handle.",
                    hints=[*graph_break_hints.SUPPORTABLE],
                )
            elif name == "data":
                # See comments on `test_set_data_on_scoped_tensor` for plans
                # to support this.
                if obj.source is None:
                    unimplemented(
                        gb_type="Failed to mutate tensor data attribute",
                        context=f"setattr({obj}, {name}, {val})",
                        explanation="Dyanmo only supports mutating `.data`"
                        " of tensor created outside `torch.compile` region",
                        hints=[
                            "Don't mutate `.data` on this tensor, or move "
                            "the mutation out of `torch.compile` region",
                        ],
                    )
                elif obj.dtype != val.dtype:  # type: ignore[attr-defined]
                    unimplemented(
                        gb_type="Failed to mutate tensor data attribute to different dtype",
                        context=f"setattr({obj}, {name}, {val})",
                        explanation="Dyanmo only supports mutating `.data`"
                        " of tensor to a new one with the same dtype",
                        hints=[
                            "Don't mutate `.data` on this tensor, or move "
                            "the mutation out of `torch.compile` region",
                        ],
                    )

                # Remove the old reference in tracked fakes - if we don't do this
                # new .data value size and shape differences will cause
                # tracked fakes to produce incorrect guards. This is sound because the TensorVariable
                # coming out of set_() below will be a new one, and get
                # installed in tracked fakes.
                to_remove = [
                    tf for tf in tx.output.tracked_fakes if tf.source == obj.source
                ]
                for tf in to_remove:
                    tx.output.tracked_fakes.remove(tf)

                # Step 1 - disable grads
                with dynamo_disable_grad(tx), torch.no_grad():
                    # Step 2 - call `set_`
                    out = wrap_fx_proxy(
                        tx,
                        tx.output.create_proxy(
                            "call_function",
                            torch.Tensor.set_,
                            *proxy_args_kwargs([obj, val], {}),
                        ),
                    )

                # Step 3 - drop the version counter - this is a step required to get
                # .data setting to play correctly with the autograd engine.
                # Essentially, dynamo is trying to faithfully preserve the (absurd)
                # behavior of .data= from eager mode
                def _lower_version_count_by_1(x: torch.Tensor) -> torch.Tensor:
                    # Decrement the version (never below zero) and write it back.
                    version = x._version
                    if version > 0:
                        version = version - 1
                    torch._C._autograd._unsafe_set_version_counter((x,), (version,))
                    return x

                # Record the decrement as a graph node ...
                tx.output.create_proxy(
                    "call_function",
                    _lower_version_count_by_1,
                    (out.as_proxy(),),
                    {},
                )
                # ... and also apply it right away to the example value.
                _lower_version_count_by_1(obj.as_proxy().node.meta["example_value"])
                # This handles options prop, guards and ends with a clone
                # Step 4 - replace all reference to the current object with the new one
                return out
            elif name in ("_grad", "grad"):
                # NOTE: [Tensor "grad" and "_grad" attr]
                # _grad and grad share the same setter/getter, see
                # THPVariable_properties, and here we make sure setting one
                # enables reading `val` from the other, by routing all
                # read/write to `grad`.
                name = "grad"
            elif is_tensor_getset_descriptor(name):
                # Attribute like `torch.Tensor.real` has special setters we
                # don't yet support; it's not as simple adding an entry to
                # the side effect mapping.
                unimplemented(
                    gb_type="Failed to set tensor attribute",
                    context=f"setattr({obj}, {name}, {val})",
                    explanation="Dyanmo doesn't support setting these tensor attributes",
                    hints=[
                        f"Don't mutate attribute '{name}' on tensors, or "
                        "move the mutation out of `torch.compile` region",
                    ],
                )

        # Record the store in the side-effects table.
        tx.output.side_effects.store_attr(obj, name, val)
        return val
    elif isinstance(obj, variables.NNModuleVariable):
        if not tx.output.is_root_tracer():
            unimplemented(
                gb_type="nn.Module mutation in HigherOrderOp",
                context=f"nn.Module: {obj}",
                explanation="Inplace modifying nn.Module params/buffers inside HigherOrderOps is not allowed.",
                hints=[
                    "Remove the mutation or move it outside of the HigherOrderOp.",
                    *graph_break_hints.FUNDAMENTAL,
                ],
            )
        if name_var.is_python_constant() and isinstance(
            val, variables.TensorVariable
        ):
            assigning_fake_val = get_fake_value(val.as_proxy().node, tx)

            try:
                getattr_var = obj.var_getattr(tx, name_var.as_python_constant())
            except (AttributeError, ObservedAttributeError):
                # The attribute does not exist yet.
                getattr_var = None

            if getattr_var is not None and getattr_var.is_tensor():
                # get_fake_val will get the same fake tensor
                existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx)

                # same tensor identity, setattr is a no-op
                # (only when the module class uses the stock
                # ``torch.nn.Module.__setattr__``)
                mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__")
                if (
                    existing_fake_attr is assigning_fake_val
                    and mod_setattr is torch.nn.Module.__setattr__
                ):
                    return getattr_var

        obj.convert_to_unspecialized(tx)
    return None
def call_delattr(
    self,
    tx: "InstructionTranslator",
    obj: VariableTracker,
    name_var: VariableTracker,
) -> VariableTracker:
    """Forward an attribute deletion to ``obj``'s ``__delattr__`` method.

    Args:
        tx: The active instruction translator.
        obj: Tracker for the object losing the attribute.
        name_var: Tracker holding the attribute name.

    Returns:
        Whatever ``obj.call_method`` returns for ``"__delattr__"``.
    """
    delattr_args = [name_var]
    return obj.call_method(tx, "__delattr__", delattr_args, {})
def call_type(
    self, tx: "InstructionTranslator", obj: VariableTracker
) -> VariableTracker:
    """Build a tracker for the Python type of ``obj``.

    Args:
        tx: The active instruction translator.
        obj: Tracker whose ``python_type()`` is queried.

    Returns:
        The result of ``VariableTracker.build`` for the type, with the
        best source that could be determined (possibly a falsy value).

    Raises:
        UserError: If ``obj.python_type()`` raises ``NotImplementedError``.
    """
    try:
        py_type = obj.python_type()
    except NotImplementedError as error:
        message = str(error)
        raise UserError(
            UserErrorType.INVALID_INPUT,
            message,
            case_name="unknown_python_type",
        ) from None

    # Wrap the object's own source when it has one; otherwise keep the
    # falsy value as-is.
    obj_source = obj.source
    source = TypeSource(obj_source) if obj_source else obj_source

    # Fall back to the class source recorded on user-defined objects.
    if (
        source is None
        and isinstance(obj, variables.UserDefinedObjectVariable)
        and obj.cls_source
    ):
        source = obj.cls_source

    if py_type is torch.Tensor:
        # In some cases torch isn't available in globals
        torch_global = tx.output.install_global_by_id("", torch)
        source = AttrSource(GlobalSource(torch_global), "Tensor")

    return VariableTracker.build(tx, py_type, source)
def call_reversed(
    self, tx: "InstructionTranslator", obj: VariableTracker
) -> VariableTracker | None:
    """Return ``obj``'s unpacked items in reverse order as a tuple tracker.

    Returns ``None`` when ``obj`` reports that it cannot be unpacked.
    """
    if not obj.has_unpack_var_sequence(tx):
        return None

    items = list(obj.unpack_var_sequence(tx))
    items.reverse()
    return variables.TupleVariable(items)
def call_sorted(
    self,
    tx: "InstructionTranslator",
    obj: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker | None:
    """Build a new list tracker from ``obj`` and sort it.

    Args:
        tx: The active instruction translator.
        obj: Tracker to be force-unpacked into a list.
        **kwargs: Passed straight through to the list's ``sort`` method.

    Returns:
        The sorted ``ListVariable``, or ``None`` when ``obj`` cannot be
        force-unpacked or is a ``TensorVariable``.
    """
    if not obj.has_force_unpack_var_sequence(tx) or isinstance(
        obj, variables.TensorVariable
    ):
        return None

    unpacked = obj.force_unpack_var_sequence(tx)
    sorted_list = variables.ListVariable(unpacked, mutation_type=ValueMutationNew())
    sorted_list.call_method(tx, "sort", [], kwargs)
    return sorted_list
# neg is a constant fold function, so we only get here if constant fold is not valid
def call_neg(
    self, tx: "InstructionTranslator", a: VariableTracker
) -> VariableTracker | None:
    """Negate a SymNode or a user-defined object that has ``__neg__``.

    Returns:
        A new ``SymNodeVariable`` for SymNode input, the result of
        ``a.__neg__()`` for user-defined objects that have it, else ``None``.
    """
    if isinstance(a, SymNodeVariable):
        negated_proxy = operator.neg(a.as_proxy())
        return SymNodeVariable.create(tx, negated_proxy, sym_num=None)

    if isinstance(a, UserDefinedObjectVariable):
        has_neg = a.call_obj_hasattr(tx, "__neg__").value  # type: ignore[attr-defined]
        if has_neg:
            return a.call_method(tx, "__neg__", [], {})

    # None no-ops this handler and lets the driving function proceed
    return None
def call_format(
    self,
    tx: "InstructionTranslator",
    _format_string: VariableTracker,
    *args: VariableTracker,
    **kwargs: VariableTracker,
) -> VariableTracker:
    """Create a ``StringFormatVariable`` from a constant format string.

    Args:
        tx: The active instruction translator (unused here).
        _format_string: Tracker whose Python constant is the format string.
        *args: Positional format arguments.
        **kwargs: Keyword format arguments.
    """
    template = str(_format_string.as_python_constant())
    return variables.StringFormatVariable.create(template, args, kwargs)
def call_id(
    self, tx: "InstructionTranslator", *args: VariableTracker
) -> VariableTracker:
    """Produce a constant tracker holding the ``id()`` of the argument.

    Supported arguments are nn.Module trackers, user-defined classes and
    objects (a guard is installed when they have a source), tensors,
    user functions, skipped functions and functools.partial trackers.
    Anything else, including a call with no arguments, goes through
    ``unimplemented``.

    Args:
        tx: The active instruction translator.
        *args: Arguments passed to ``id``.

    Returns:
        A ``ConstantVariable`` with the id, or for tensors the result of
        ``TensorVariable.call_id``.
    """
    # With no arguments the unguarded ``args[0]`` lookups below used to raise
    # IndexError. ``None`` matches none of the type checks, so that case now
    # falls through to the ``unimplemented`` branch.
    first = args[0] if args else None
    single = len(args) == 1

    if isinstance(first, variables.NNModuleVariable):
        mod = tx.output.get_submodule(first.module_key)
        return variables.ConstantVariable.create(id(mod))
    elif single and isinstance(
        first,
        (variables.UserDefinedClassVariable, variables.UserDefinedObjectVariable),
    ):
        if first.source:
            # The id is baked in as a constant, so guard on the class or
            # object it was taken from.
            if isinstance(first, variables.UserDefinedClassVariable):
                install_guard(first.source.make_guard(GuardBuilder.CLASS_MATCH))
            else:
                install_guard(first.source.make_guard(GuardBuilder.ID_MATCH))
        constant_result = id(first.value)
        return variables.ConstantVariable.create(constant_result)
    elif single and first.is_tensor():
        tensor_variable = cast(TensorVariable, first)
        return tensor_variable.call_id(tx)
    elif istype(first, variables.UserFunctionVariable):
        return variables.ConstantVariable.create(id(first.fn))
    elif istype(first, variables.SkipFunctionVariable):
        return variables.ConstantVariable.create(id(first.value))
    elif istype(first, variables.FunctoolsPartialVariable):
        return variables.ConstantVariable.create(id(first.fake_value))
    else:
        unimplemented(
            gb_type="id() with unsupported args",
            context=str(args),
            explanation=f"Dynamo doesn't know how to trace id() call with args {args}",
            hints=[
                "Supported args are Tensors, and functions/nn.Modules/user-defined objects "
                "from outside the compiled region.",
                *graph_break_hints.SUPPORTABLE,
            ],
        )
def call_deepcopy(
    self, tx: "InstructionTranslator", x: VariableTracker
) -> VariableTracker:
    """Report ``copy.deepcopy`` as unsupported via ``unimplemented``."""
    context = f"copy.deepcopy({x})"
    hints = [
        "Avoid calling copy.deepcopy()",
        *graph_break_hints.SUPPORTABLE,
    ]
    unimplemented(
        gb_type="copy.deepcopy()",
        context=context,
        explanation="Dynamo does not support copy.deepcopy()",
        hints=hints,
    )
def _comparison_with_tensor(
    self, tx: "InstructionTranslator", left: VariableTracker, right: VariableTracker
) -> VariableTracker:
    """Trace the comparison ``self.fn(left, right)`` involving a tensor.

    ``is`` / ``is not`` are resolved to a constant by comparing the identity
    of the two operands' fake example values. Other supported comparison
    ops are emitted into the graph as a ``call_function`` node.

    Args:
        tx: The active instruction translator.
        left: Left operand.
        right: Right operand.

    Returns:
        A ``ConstantVariable`` for identity comparisons, otherwise a
        tracker of the same class as the tensor operand wrapping the new
        proxy.
    """
    from .builder import wrap_fx_proxy_cls
    from .tensor import supported_tensor_comparison_op_values

    op = self.fn

    if op in (operator.is_, operator.is_not):
        is_result = (
            left.is_tensor()
            and right.is_tensor()
            and id(extract_fake_example_value(left.as_proxy().node))
            == id(extract_fake_example_value(right.as_proxy().node))
        )
        if op is operator.is_:
            return ConstantVariable.create(is_result)
        return ConstantVariable.create(not is_result)

    if op not in supported_tensor_comparison_op_values:
        unimplemented(
            gb_type="unsupported Tensor comparison op",
            context=f"{op.__name__}({left}, {right})",
            explanation=f"Dynamo does not support the comparison op {op.__name__} "
            f"with Tensor arguments {left}, {right}",
            hints=[*graph_break_hints.SUPPORTABLE],
        )

    # Only check broadcasting when BOTH sizes are known. The sizes are tested
    # against None one at a time: `(left.size and right.size) is not None`
    # let an empty (falsy) left shape through even when right.size was None.
    if (
        isinstance(left, TensorVariable)
        and isinstance(right, TensorVariable)
        and left.size is not None
        and right.size is not None
        and left.size != right.size
    ):
        try:
            torch.broadcast_shapes(left.size, right.size)
        except RuntimeError:
            # not broadcastable, can't be compared
            unimplemented(
                gb_type="failed to broadcast when attempting Tensor comparison op",
                context=f"{op.__name__}({left}, {right})",
                explanation=f"Dynamo was unable to broad cast the arguments {left}, {right} "
                f"when attempting to trace the comparison op {op.__name__}.",
                hints=[*graph_break_hints.USER_ERROR],
            )

    tensor_cls = left if left.is_tensor() else right
    proxy = tx.output.create_proxy(
        "call_function", op, (left.as_proxy(), right.as_proxy()), {}
    )
    return wrap_fx_proxy_cls(
        type(tensor_cls),  # handle Ndarrays and Tensors
        tx,
        proxy,
    )
def _comparison_with_symnode(
    self, tx: "InstructionTranslator", left: VariableTracker, right: VariableTracker
) -> VariableTracker:
    """Trace the comparison ``self.fn(left, right)`` involving a SymNode.

    Ops outside ``supported_tensor_comparison_op_values`` go through
    ``unimplemented``. A user-defined class on the right-hand side yields
    a constant; any other pair becomes a ``call_function`` node wrapped
    in a ``SymNodeVariable``.
    """
    from .tensor import supported_tensor_comparison_op_values

    op = self.fn
    if op not in supported_tensor_comparison_op_values:
        op_name = op.__name__
        unimplemented(
            gb_type="unsupported SymNode comparison op",
            context=f"{op_name}({left}, {right})",
            explanation=f"Dynamo does not support the comparison op {op_name} "
            f"with SymNode arguments {left}, {right}",
            hints=[*graph_break_hints.SUPPORTABLE],
        )

    # This is seen in inspect signature where we check if the value is a default value
    if isinstance(right, variables.UserDefinedClassVariable):
        constant_result = op(object(), None)
        return variables.ConstantVariable(constant_result)

    operands = (left.as_proxy(), right.as_proxy())
    proxy = tx.output.create_proxy("call_function", op, operands, {})
    return SymNodeVariable.create(tx, proxy, sym_num=None)
def call_xor(
    self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
) -> VariableTracker | None:
    """Handle ``a ^ b`` for SymNode-like operands and set-like left operands.

    Returns ``None`` for two constants and for any operand combination
    not handled here.
    """
    # Rely on constant_handler
    if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
        return None

    if a.is_symnode_like() and b.is_symnode_like():
        proxy_args, proxy_kwargs = proxy_args_kwargs([a, b], {})
        xor_proxy = tx.output.create_proxy(
            "call_function", operator.xor, proxy_args, proxy_kwargs
        )
        return SymNodeVariable.create(tx, xor_proxy, sym_num=None)

    dunder_owners = (DictKeysVariable, SetVariable, UserDefinedObjectVariable)
    if not isinstance(a, dunder_owners):
        return None
    return a.call_method(tx, "__xor__", [b], {})
def call_ixor(
    self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
) -> VariableTracker | None:
    """Handle ``a ^= b`` by calling ``a.__ixor__(b)`` for set-like trackers.

    Returns ``None`` when ``a`` is not a dict-keys, set or user-defined
    object tracker.
    """
    dunder_owners = (DictKeysVariable, SetVariable, UserDefinedObjectVariable)
    if not isinstance(a, dunder_owners):
        return None
    return a.call_method(tx, "__ixor__", [b], {})
def call_sub(
    self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
) -> VariableTracker | None:
    """Handle ``a - b`` by calling ``a.__sub__(b)`` for set-like trackers.

    Returns ``None`` when ``a`` is not a dict-keys, set or user-defined
    object tracker.
    """
    dunder_owners = (DictKeysVariable, SetVariable, UserDefinedObjectVariable)
    if not isinstance(a, dunder_owners):
        return None
    return a.call_method(tx, "__sub__", [b], {})
def call_isub(
    self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
) -> VariableTracker | None:
    """Handle ``a -= b`` by calling ``a.__isub__(b)`` for set-like trackers.

    Returns ``None`` when ``a`` is not a dict-keys, set or user-defined
    object tracker.
    """
    dunder_owners = (DictKeysVariable, SetVariable, UserDefinedObjectVariable)
    if not isinstance(a, dunder_owners):
        return None
    return a.call_method(tx, "__isub__", [b], {})
def call_and_(
    self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
) -> VariableTracker | None:
    """Handle ``a & b`` for SymNode-like operands and set-like left operands.

    Returns ``None`` for two constants and for any operand combination
    not handled here.
    """
    # Rely on constant_handler
    if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
        return None

    if a.is_symnode_like() and b.is_symnode_like():
        proxy_args, proxy_kwargs = proxy_args_kwargs([a, b], {})
        and_proxy = tx.output.create_proxy(
            "call_function", operator.and_, proxy_args, proxy_kwargs
        )
        return SymNodeVariable.create(tx, and_proxy, sym_num=None)

    dunder_owners = (DictKeysVariable, SetVariable, UserDefinedObjectVariable)
    if not isinstance(a, dunder_owners):
        # None no-ops this handler and lets the driving function proceed
        return None
    return a.call_method(tx, "__and__", [b], {})
def call_iand(
    self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
) -> VariableTracker | None:
    """Handle ``a &= b`` for SymNode-like operands and set-like left operands.

    Returns ``None`` for two constants and for any operand combination
    not handled here.
    """
    # Rely on constant_handler
    if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
        return None

    if a.is_symnode_like() and b.is_symnode_like():
        proxy_args, proxy_kwargs = proxy_args_kwargs([a, b], {})
        iand_proxy = tx.output.create_proxy(
            "call_function", operator.iand, proxy_args, proxy_kwargs
        )
        return SymNodeVariable.create(tx, iand_proxy, sym_num=None)

    dunder_owners = (DictKeysVariable, SetVariable, UserDefinedObjectVariable)
    if not isinstance(a, dunder_owners):
        return None
    return a.call_method(tx, "__iand__", [b], {})
def call_or_(
    self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
) -> VariableTracker | None:
    """Handle ``a | b`` for SymNode-like operands and dict/set-like left operands.

    Returns ``None`` for two constants and for any operand combination
    not handled here.
    """
    # Rely on constant_handler
    if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
        return None

    if a.is_symnode_like() and b.is_symnode_like():
        proxy_args, proxy_kwargs = proxy_args_kwargs([a, b], {})
        or_proxy = tx.output.create_proxy(
            "call_function", operator.or_, proxy_args, proxy_kwargs
        )
        return SymNodeVariable.create(tx, or_proxy, sym_num=None)

    # This call looks like `{"one": torch.ones(1)} | {"two": torch.ones(2)}`.
    dunder_owners = (
        ConstDictVariable,
        DictKeysVariable,
        MutableMappingVariable,
        SetVariable,
        UserDefinedDictVariable,
        UserDefinedObjectVariable,
    )
    if not isinstance(a, dunder_owners):
        # None no-ops this handler and lets the driving function proceed
        return None

    # TODO(guilhermeleobas): forward the call to b.__ror__(a) if
    # a.__ror__(b) returns NotImplemented
    return a.call_method(tx, "__or__", [b], {})
def call_ior(
    self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
) -> VariableTracker | None:
    """Handle ``a |= b`` for SymNode-like operands and dict/set-like left operands.

    Returns ``None`` for two constants and for any operand combination
    not handled here.
    """
    # Rely on constant_handler
    if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
        return None

    if a.is_symnode_like() and b.is_symnode_like():
        proxy_args, proxy_kwargs = proxy_args_kwargs([a, b], {})
        ior_proxy = tx.output.create_proxy(
            "call_function", operator.ior, proxy_args, proxy_kwargs
        )
        return SymNodeVariable.create(tx, ior_proxy, sym_num=None)

    # This call looks like `{"one": torch.ones(1)} |= {"two": torch.ones(2)}`.
    dunder_owners = (
        ConstDictVariable,
        DictKeysVariable,
        MutableMappingVariable,
        SetVariable,
        UserDefinedObjectVariable,
    )
    if not isinstance(a, dunder_owners):
        # None no-ops this handler and lets the driving function proceed
        return None
    return a.call_method(tx, "__ior__", [b], {})
def call_not_(
    self, tx: "InstructionTranslator", a: VariableTracker
) -> VariableTracker | None:
    """Handle ``not a`` for SymNodes, lists, dicts and dict views.

    Returns:
        A ``SymNodeVariable`` for SymNode input, a constant that is
        ``True`` when a list/dict (or the dict behind a dict view) has no
        items, and ``None`` for anything else.
    """
    if isinstance(a, SymNodeVariable):
        proxy_args, proxy_kwargs = proxy_args_kwargs([a], {})
        not_proxy = tx.output.create_proxy(
            "call_function", operator.not_, proxy_args, proxy_kwargs
        )
        return SymNodeVariable.create(tx, not_proxy, sym_num=None)

    # Unwrap the underlying ConstDictVariable
    container = a.dv_dict if isinstance(a, DictViewVariable) else a

    if not isinstance(container, (ListVariable, ConstDictVariable)):
        return None
    is_empty = len(container.items) == 0
    return ConstantVariable.create(is_empty)
def call_contains(
    self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
) -> VariableTracker:
    """Forward a containment check to ``a.__contains__(b)``.

    Args:
        tx: The active instruction translator.
        a: Tracker for the container.
        b: Tracker for the item being looked up.
    """
    contains_args = [b]
    return a.call_method(tx, "__contains__", contains_args, {})
def is_python_hashable(self) -> Literal[True]:
    """Report that this tracker is always hashable."""
    return True
def get_python_hash(self) -> int:
    """Return the hash of the wrapped callable ``self.fn``."""
    wrapped_fn = self.fn
    return hash(wrapped_fn)
def is_python_equal(self, other: object) -> bool:
    """Return True when ``other`` is a ``BuiltinVariable`` wrapping the same ``fn`` object."""
    if not isinstance(other, variables.BuiltinVariable):
        return False
    return self.fn is other.fn
@contextlib.contextmanager
def dynamo_disable_grad(tx: "InstructionTranslator") -> typing.Iterator[None]:
    """Context manager that runs its body inside a ``GradModeVariable`` for ``False``.

    The variable is entered before the body runs, and its ``exit`` is
    always called afterwards, even if entering or the body raises.
    """
    from . import GradModeVariable

    grad_mode_off = GradModeVariable.create(tx, False)
    try:
        grad_mode_off.enter(tx)
        yield
    finally:
        grad_mode_off.exit(tx)
|