misc.py 94 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445
  1. """
  2. This module contains miscellaneous variable tracker implementations for various Python types
  3. and features used in Dynamo's symbolic execution. These classes help track and propagate
  4. information about different kinds of variables during graph capture.
  5. Key classes include:
  6. - SuperVariable: Handles super() calls and method resolution
  7. - ExceptionVariable: Tracks exception objects
  8. - RandomVariable: Manages random number generators
  9. - GetAttrVariable: Tracks attribute access
  10. - MethodWrapperVariable: Handles method wrappers
  11. - PythonModuleVariable: Tracks Python modules
  12. - NumpyVariable: Handles numpy functions and types
  13. - StringFormatVariable: Manages string formatting
  14. - DebuggingVariable: Handles print and logging
  15. """
  16. import dataclasses
  17. import enum
  18. import functools
  19. import inspect
  20. import itertools
  21. import logging
  22. import random
  23. import re
  24. import sys
  25. import traceback
  26. import types
  27. import weakref
  28. from collections.abc import Callable, Sequence
  29. from random import Random
  30. from types import BuiltinFunctionType
  31. from typing import Any, Literal, NoReturn, TYPE_CHECKING, TypeGuard, Union
  32. import torch._C
  33. import torch._numpy as tnp
  34. import torch.utils._pytree as pytree
  35. from torch._dynamo.variables.base import MutationType
  36. from torch._dynamo.variables.lists import TupleVariable
  37. from torch._guards import Source
  38. from .. import config, graph_break_hints, trace_rules, variables
  39. from ..bytecode_transformation import (
  40. create_call_function,
  41. create_call_function_ex,
  42. create_instruction,
  43. )
  44. from ..create_parameter_op import do_not_convert_to_tracable_parameter
  45. from ..exc import raise_observed_exception, unimplemented
  46. from ..guards import GuardBuilder, install_guard
  47. from ..mutation_guard import unpatched_nn_module_init
  48. from ..source import (
  49. AttrSource,
  50. GenericAttrSource,
  51. GetItemSource,
  52. TypeMROSource,
  53. TypeSource,
  54. WeakRefCallSource,
  55. )
  56. from ..utils import (
  57. check_unspec_or_constant_args,
  58. cmp_name_to_op_mapping,
  59. identity,
  60. is_tensor_base_attr_getter,
  61. istype,
  62. list_methods,
  63. proxy_args_kwargs,
  64. raise_args_mismatch,
  65. tuple_methods,
  66. )
  67. from .base import (
  68. AsPythonConstantNotImplementedError,
  69. raise_type_error_exc,
  70. VariableTracker,
  71. )
  72. from .constant import CONSTANT_VARIABLE_NONE, ConstantVariable
  73. from .functions import NestedUserFunctionVariable, UserFunctionVariable
  74. from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable
  75. if TYPE_CHECKING:
  76. from torch._dynamo.codegen import PyCodegen
  77. from torch._dynamo.symbolic_convert import InstructionTranslator
  78. class NO_SUCH_SUBOBJ:
  79. pass
  80. class SuperVariable(VariableTracker):
  81. _nonvar_fields = {
  82. *VariableTracker._nonvar_fields,
  83. }
  84. def __init__(
  85. self,
  86. typevar: VariableTracker,
  87. objvar: VariableTracker | None = None,
  88. **kwargs: Any,
  89. ) -> None:
  90. super().__init__(**kwargs)
  91. # typevar is the first argument to super(). In the case where no argument
  92. # is provided to super(), it is the __class__ object where
  93. # the super() function is being called
  94. self.typevar = typevar
  95. # objvar here must be an instance or subtype of typevar.
  96. # In the case where super() is called without arguments, it is the first argument
  97. # to the current function where super() is called from (self for regular method,
  98. # cls for a classmethod)
  99. self.objvar = objvar
  100. def reconstruct(self, codegen: "PyCodegen") -> None:
  101. codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super)))
  102. codegen(self.typevar)
  103. if self.objvar is not None:
  104. codegen(self.objvar)
  105. codegen.extend_output(create_call_function(2, False))
  106. else:
  107. codegen.extend_output(create_call_function(1, False))
  108. def _resolved_getattr_and_source(
  109. self, tx: "InstructionTranslator", name: str
  110. ) -> tuple[Any, AttrSource | None]:
  111. if not self.objvar:
  112. unimplemented(
  113. gb_type="1-arg super not implemented",
  114. context="",
  115. explanation=f"Dynamo failed to trace attribute `{name}` accessed "
  116. f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) "
  117. "because one-argument of super() is not supported.",
  118. hints=[
  119. "Use two-argument super(type, object_or_type).",
  120. ],
  121. )
  122. assert self.objvar is not None
  123. search_type = self.typevar.as_python_constant()
  124. # The rest of this function does two things:
  125. # - Walk the mro to find where the attribute comes from to be
  126. # able to provide accurate source
  127. # - Call the getattr to get the object
  128. # Find the class object, where the function lives.
  129. # When objvar is "self", use type(self), when objvar is "cls", use it as-is
  130. type_to_use = self.objvar.python_type()
  131. type_to_use_source: Source | None = (
  132. TypeSource(self.objvar.source) if self.objvar.source else None
  133. )
  134. if issubclass(type_to_use, type):
  135. type_to_use = self.objvar.value # type: ignore[attr-defined]
  136. type_to_use_source = self.objvar.source
  137. source = None
  138. search_mro = type_to_use.__mro__
  139. try:
  140. start_index = search_mro.index(search_type) + 1
  141. except ValueError:
  142. # Corner case where the typevar is not in the mro of the objvar
  143. # https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8843-L8844
  144. return getattr(super(search_type, type_to_use), name), None
  145. # Implemented based on https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8812
  146. # super has its getattro implementation. The key point is that instead of calling getattr, it checks the
  147. # attribute in the class __dict__
  148. for index in range(start_index, len(search_mro)):
  149. # Dont call getattr, just check the __dict__ of the class
  150. if resolved_getattr := search_mro[index].__dict__.get(name, NO_SUCH_SUBOBJ):
  151. if resolved_getattr is not NO_SUCH_SUBOBJ:
  152. # Equivalent of something like type(L['self']).__mro__[1].attr_name
  153. if type_to_use_source:
  154. source = AttrSource(
  155. GetItemSource(TypeMROSource(type_to_use_source), index),
  156. name,
  157. )
  158. return resolved_getattr, source
  159. unimplemented(
  160. gb_type="Unable to resolve super getattr",
  161. context="",
  162. explanation=f"Dynamo failed to trace attribute `{name}` accessed "
  163. f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) "
  164. "because the resolved attribute type is not supported.",
  165. hints=[
  166. "Ensure the attribute exists in the parent class.",
  167. "Check the arguments passed to `super()`.",
  168. ],
  169. )
  170. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  171. # Check if getattr is a constant. If not, delay the actual work by
  172. # wrapping the result in GetAttrVariable. Mostly super is called with a
  173. # method, so most of the work is delayed to call_function.
  174. #
  175. # We could have just implemented a const_getattr. However, super is
  176. # special when it comes to finding sources. Compared to other VTs, super
  177. # requires the attr name to walk the mro and find the actual source (and
  178. # not just AttrSource).
  179. value, source = self._resolved_getattr_and_source(tx, name)
  180. if not variables.ConstantVariable.is_literal(value):
  181. return GetAttrVariable(self, name)
  182. if source:
  183. install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
  184. return variables.ConstantVariable.create(value, source=source)
  185. def call_method(
  186. self,
  187. tx: "InstructionTranslator",
  188. name: str,
  189. args: list[VariableTracker],
  190. kwargs: dict[str, VariableTracker],
  191. ) -> VariableTracker:
  192. inner_fn, source = self._resolved_getattr_and_source(tx, name)
  193. assert self.objvar is not None
  194. # This essentially simulates CPython's `super_getattro`:
  195. # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L11138-L11168
  196. # where `inner_fn` is the VT for `res = _super_lookup_descr(...)`.
  197. #
  198. # However, `res`'s type needs to be checked for `tp_descr_get`, and
  199. # applied if it has one. We currently don't have polyfills for all the
  200. # relevant `tp_descr_get`, so we explicitly handle the cases we care
  201. # about here (e.g., note the staticmethod, classmethod cases).
  202. if inner_fn is object.__init__:
  203. return LambdaVariable(identity)
  204. elif inner_fn is torch.nn.Module.__init__:
  205. objvar = self.objvar
  206. from ..side_effects import AttributeMutationNew
  207. if (
  208. isinstance(objvar, variables.UserDefinedObjectVariable)
  209. and isinstance(objvar.mutation_type, AttributeMutationNew)
  210. and not (args or kwargs)
  211. ):
  212. with do_not_convert_to_tracable_parameter():
  213. fn_vt = VariableTracker.build(
  214. tx, unpatched_nn_module_init, source=source
  215. )
  216. return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
  217. else:
  218. unimplemented(
  219. gb_type="Unsupported super().__init__() call",
  220. context=f"call_method {self} {name} {args} {kwargs}",
  221. explanation="Dynamo encountered a super().__init__() call "
  222. f"on {objvar} that resolved to a `torch.nn.Module.__init__()` "
  223. "call that we cannot trace.",
  224. hints=[*graph_break_hints.DIFFICULT],
  225. )
  226. elif (
  227. self.objvar.source
  228. and hasattr(inner_fn, "__name__")
  229. and inner_fn.__name__ == "__new__"
  230. and variables.UserDefinedClassVariable.is_supported_new_method(inner_fn)
  231. ):
  232. user_cls = inner_fn.__self__
  233. if hasattr(user_cls, "__module__") and user_cls.__module__ == "builtins":
  234. user_cls_vt: VariableTracker = variables.BuiltinVariable(user_cls)
  235. else:
  236. assert source is not None
  237. user_cls_source = source.member
  238. user_cls_vt = variables.UserDefinedClassVariable(
  239. user_cls, source=user_cls_source
  240. )
  241. return user_cls_vt.call_method(tx, "__new__", args, kwargs)
  242. elif isinstance(inner_fn, staticmethod) and isinstance(
  243. inner_fn.__func__, types.FunctionType
  244. ):
  245. fn_vt = VariableTracker.build(
  246. tx, inner_fn.__func__, source=source, realize=True
  247. )
  248. return fn_vt.call_function(tx, args, kwargs)
  249. elif isinstance(inner_fn, classmethod) and isinstance(
  250. inner_fn.__func__, types.FunctionType
  251. ):
  252. if isinstance(self.objvar, variables.UserDefinedClassVariable):
  253. # super().classmethod is called from a classmethod itself. So,
  254. # super was converted to super(__class__, cls) in bytecode and
  255. # therefore we have to propagate the cls.
  256. cls_variable = self.objvar
  257. else:
  258. # current function is an instance method, therefore super was
  259. # converted to super(__class__, self). We have to find
  260. # type(self) to bind the cls to the parent classmethod.
  261. # Note that it can't be the self.typevar because __class__ is
  262. # the class where the method is defined, which could be
  263. # different from type(self) with polymorphism.
  264. cls_source = None
  265. if self.objvar.source:
  266. cls_source = TypeSource(self.objvar.source)
  267. cls_variable = VariableTracker.build(
  268. tx,
  269. self.objvar.value_type, # type: ignore[attr-defined]
  270. cls_source,
  271. )
  272. assert source is not None
  273. fn_vt = VariableTracker.build(
  274. tx,
  275. inner_fn.__func__,
  276. source=AttrSource(source, "__func__"),
  277. realize=True,
  278. )
  279. return fn_vt.call_function(tx, [cls_variable, *args], kwargs)
  280. elif isinstance(inner_fn, types.FunctionType):
  281. fn_vt = VariableTracker.build(tx, inner_fn, source=source, realize=True)
  282. return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
  283. elif isinstance(inner_fn, types.MethodType):
  284. return variables.UserMethodVariable(
  285. inner_fn.__func__, self.objvar, source=source
  286. ).call_function(tx, args, kwargs)
  287. elif is_standard_setattr(inner_fn) and isinstance(
  288. self.objvar, UserDefinedObjectVariable
  289. ):
  290. # type: ignore[arg-type]
  291. return self.objvar.method_setattr_standard(tx, *args, **kwargs)
  292. elif inner_fn is object.__delattr__:
  293. attr = args[0]
  294. try:
  295. attr = attr.as_python_constant()
  296. except NotImplementedError as exc:
  297. unimplemented(
  298. gb_type="Non-constant attribute given to `super().__delattr__()`",
  299. context=f"call_method {self} {name}",
  300. explanation="Dynamo requires the attribute name passed to "
  301. "`super().__delattr__(...)` to be a constant (string).",
  302. hints=[
  303. "Ensure the attribute name is a string literal or a constant variable."
  304. ],
  305. from_exc=exc,
  306. )
  307. if not tx.output.side_effects.is_attribute_mutation(self.objvar):
  308. unimplemented(
  309. gb_type="Attempted super().__delattr__() on an object without mutation tracking",
  310. context=f"call_method {self} {name}",
  311. explanation="Dynamo needs to track mutations on an object "
  312. "before `super().__delattr__` can be used on it. But the "
  313. f"object ({self.objvar}) doesn't have attribute mutation "
  314. "tracking enabled.",
  315. hints=[
  316. "Ensure the object is tracked by Dynamo's side effect system.",
  317. *graph_break_hints.DYNAMO_BUG,
  318. ],
  319. )
  320. assert isinstance(attr, str)
  321. tx.output.side_effects.store_attr(
  322. self.objvar, attr, variables.DeletedVariable()
  323. )
  324. return variables.CONSTANT_VARIABLE_NONE
  325. elif (
  326. isinstance(self.objvar, variables.UserDefinedDictVariable)
  327. and inner_fn in self.objvar._dict_methods
  328. ):
  329. return self.objvar._dict_vt.call_method(tx, name, args, kwargs)
  330. elif (
  331. isinstance(self.objvar, variables.UserDefinedSetVariable)
  332. and inner_fn in self.objvar._set_methods
  333. ):
  334. return self.objvar._set_vt.call_method(tx, name, args, kwargs)
  335. elif (
  336. isinstance(self.objvar, variables.UserDefinedTupleVariable)
  337. and inner_fn in tuple_methods
  338. ):
  339. return self.objvar._tuple_vt.call_method(tx, name, args, kwargs)
  340. elif (
  341. isinstance(self.objvar, variables.UserDefinedListVariable)
  342. and inner_fn in list_methods
  343. ):
  344. return self.objvar._list_vt.call_method(tx, name, args, kwargs)
  345. elif inner_fn is object.__getattribute__:
  346. # object.__getattribute__ has no side-effects. We can directly call
  347. # __getattribute__ to access the attribute.
  348. attr_name = args[0].value # type: ignore[attr-defined]
  349. if tx.output.side_effects.has_pending_mutation_of_attr(
  350. self.objvar, attr_name
  351. ):
  352. result = tx.output.side_effects.load_attr(
  353. self.objvar, attr_name, deleted_ok=True
  354. )
  355. if isinstance(result, variables.DeletedVariable):
  356. raise_observed_exception(AttributeError, tx)
  357. return result
  358. attr_value = None
  359. try:
  360. # NB - use object.__getattribute__ to prevent running any user code
  361. # type: ignore[attr-defined]
  362. attr_value = object.__getattribute__(self.objvar.value, attr_name)
  363. except AttributeError:
  364. raise_observed_exception(AttributeError, tx)
  365. attr_source = None
  366. if self.objvar.source is not None:
  367. # setup a object.__getattribute__(self.objvar, name) source
  368. attr_source = GenericAttrSource(self.objvar.source, attr_name)
  369. return VariableTracker.build(tx, attr_value, attr_source)
  370. elif inner_fn is torch._C._disabled_torch_function_impl:
  371. # See `THPModule_disable_torch_function` for the C impl.
  372. # The signature of _disabled_torch_function_impl is similar to
  373. # `__torch_function__`, just without the first `cls` argument:
  374. # * (func, types, args, kwargs)
  375. func = args[0]
  376. # pyrefly: ignore [implicit-any]
  377. tf_kwargs = {}
  378. tf_args = args[2].items # type: ignore[attr-defined]
  379. # type: ignore[attr-defined]
  380. for hash_key_vt, value_vt in args[3].items.items():
  381. key_str = hash_key_vt.vt.as_python_constant()
  382. tf_kwargs[key_str] = value_vt
  383. tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled
  384. tx.symbolic_torch_function_state.torch_function_subclass_enabled = False
  385. try:
  386. return func.call_function(tx, tf_args, tf_kwargs)
  387. finally:
  388. tx.symbolic_torch_function_state.torch_function_subclass_enabled = (
  389. tx_old
  390. )
  391. elif (
  392. isinstance(inner_fn, types.MethodDescriptorType)
  393. and inner_fn in trace_rules.get_tensor_method()
  394. ):
  395. # FunctionType but implementation is in C, we support some of these,
  396. # e.g., tensor ops like `torch.Tensor.to`.
  397. fn_var = VariableTracker.build(tx, inner_fn, source, realize=True)
  398. return fn_var.call_function(tx, [self.objvar] + args, kwargs)
  399. unimplemented(
  400. gb_type="Attempted to call a super() attribute that is "
  401. "not a function or method",
  402. context=f"call_method {self} {name}",
  403. explanation="Dynamo does not know how to trace the call "
  404. f"`super().{name}()` because `super().{name}` is not a "
  405. "function or method attribute.",
  406. hints=[
  407. "Ensure the attribute accessed via `super()` is a standard method or function.",
  408. ],
  409. )
  410. class FrameSummaryVariable(VariableTracker):
  411. def __init__(self, frame_summary: traceback.FrameSummary, **kwargs: Any) -> None:
  412. super().__init__(**kwargs)
  413. self.frame_summary = frame_summary
  414. def python_type(self) -> type:
  415. return traceback.FrameSummary
  416. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  417. if name == "lineno":
  418. return variables.ConstantVariable.create(self.frame_summary.lineno)
  419. elif name == "filename":
  420. return variables.ConstantVariable.create(self.frame_summary.filename)
  421. elif name == "name":
  422. return variables.ConstantVariable.create(self.frame_summary.name)
  423. elif name == "line":
  424. return variables.ConstantVariable.create(self.frame_summary.line)
  425. return super().var_getattr(tx, name)
  426. class TracebackVariable(VariableTracker):
  427. def __init__(
  428. self,
  429. frame_summary: FrameSummaryVariable,
  430. tb_next: Union["TracebackVariable", ConstantVariable],
  431. **kwargs: Any,
  432. ) -> None:
  433. # The traceback holds four attributes:
  434. # - tb_frame
  435. # - tb_lineno
  436. # - tb_lasti
  437. # - tb_next
  438. super().__init__(**kwargs)
  439. self.frame_summary = frame_summary
  440. # the next traceback in the chain
  441. assert tb_next is not None
  442. self.tb_next = tb_next
  443. @classmethod
  444. def from_frame_summary(
  445. cls,
  446. frame_summary: traceback.FrameSummary,
  447. tb_next: Union["TracebackVariable", ConstantVariable],
  448. ) -> "TracebackVariable":
  449. return cls(FrameSummaryVariable(frame_summary), tb_next=tb_next)
  450. @staticmethod
  451. def is_valid_traceback(obj: VariableTracker) -> bool:
  452. return istype(obj, TracebackVariable) or (
  453. istype(obj, ConstantVariable) and obj.is_constant_none()
  454. )
  455. def extract_tb(self) -> list[traceback.FrameSummary | FrameSummaryVariable]:
  456. if istype(self.tb_next, ConstantVariable):
  457. return [self.frame_summary]
  458. return [self.frame_summary] + self.tb_next.extract_tb()
  459. def has_reference_cycle(self, tb: VariableTracker) -> bool:
  460. # checks if `tb` is in the chain of tb_next starting from `self`
  461. curr_tb: TracebackVariable | ConstantVariable = self
  462. while istype(curr_tb, TracebackVariable):
  463. if curr_tb is tb:
  464. return True
  465. curr_tb = curr_tb.tb_next
  466. return False
  467. def python_type(self) -> type[types.TracebackType]:
  468. return types.TracebackType
  469. def call_setattr(
  470. self,
  471. tx: "InstructionTranslator",
  472. name_var: VariableTracker,
  473. val: VariableTracker,
  474. ) -> VariableTracker:
  475. name = name_var.as_python_constant()
  476. if name == "tb_next":
  477. if not self.is_valid_traceback(val):
  478. raise_observed_exception(TypeError, tx)
  479. assert isinstance(val, (TracebackVariable, ConstantVariable))
  480. if self.has_reference_cycle(val) or (
  481. istype(val, TracebackVariable) and val.has_reference_cycle(self)
  482. ):
  483. raise_observed_exception(ValueError, tx)
  484. self.tb_next = val
  485. return variables.CONSTANT_VARIABLE_NONE
  486. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  487. if name == "tb_next":
  488. return self.tb_next
  489. elif name == "tb_lineno":
  490. return self.frame_summary.var_getattr(tx, "lineno")
  491. elif name == "frame_summary":
  492. return self.frame_summary
  493. elif name == "tb_lasti":
  494. unimplemented(
  495. gb_type="traceback.tb_lasti not supported",
  496. context=f"{self} accessing 'tb_lasti'",
  497. explanation="Dynamo does not support accessing the tb_lasti attribute of traceback objects.",
  498. hints=[*graph_break_hints.SUPPORTABLE],
  499. )
  500. return super().var_getattr(tx, name)
  501. def call_method(
  502. self,
  503. tx: "InstructionTranslator",
  504. name: str,
  505. args: list[VariableTracker],
  506. kwargs: dict[str, VariableTracker],
  507. ) -> VariableTracker:
  508. if name == "__eq__":
  509. # Two traceback variables are only equal if they are the same object
  510. return variables.ConstantVariable.create(self is args[0])
  511. elif name == "__setattr__":
  512. return self.call_setattr(tx, *args)
  513. return super().call_method(tx, name, args, kwargs)
  514. class ExceptionVariable(VariableTracker):
  515. # The ExceptionVariable corresponds to the BaseException class in Python
  516. def __init__(
  517. self,
  518. exc_type: Any,
  519. args: tuple[VariableTracker, ...],
  520. init_kwargs: dict[str, VariableTracker] | None = None,
  521. source: Source | None = None,
  522. mutation_type: MutationType | None = None,
  523. ) -> None:
  524. super().__init__(source=source, mutation_type=mutation_type)
  525. self.exc_type = exc_type
  526. self.args = args
  527. if init_kwargs:
  528. unimplemented(
  529. gb_type="Keyword args passed to exception constructor",
  530. context=f"{self} with kwargs {init_kwargs}",
  531. explanation="Dynamo does not know how to handle keyword args passed to an exception constructor",
  532. hints=[*graph_break_hints.SUPPORTABLE],
  533. )
  534. # When raising a new exception while another exception is already being
  535. # handled, the new exception's __context__ attribute is automatically
  536. # set to the handled exception.
  537. self.__context__: VariableTracker = CONSTANT_VARIABLE_NONE
  538. # Set when user raised an exception from another:
  539. # raise ... from ...
  540. self.__cause__: VariableTracker = CONSTANT_VARIABLE_NONE
  541. # Boolean flag that controls whether the __context__ attribute is set
  542. self.__suppress_context__: VariableTracker = ConstantVariable(False)
  543. # Contains the call stack where the exception was raised.
  544. self.__traceback__: VariableTracker = CONSTANT_VARIABLE_NONE
  545. # The user stack at the time this exception was first raised.
  546. # Used to preserve the original exception location when re-raising.
  547. self.python_stack: traceback.StackSummary | None = None
  548. def set_context(self, context: VariableTracker) -> None:
  549. self.__context__ = context
  550. def reconstruct(self, codegen: "PyCodegen") -> None:
  551. codegen.add_push_null(
  552. lambda: codegen.load_import_from("builtins", self.exc_type.__name__)
  553. )
  554. codegen.foreach(self.args)
  555. codegen.call_function(len(self.args), False)
  556. def codegen_attr(name: str) -> None:
  557. attr = getattr(self, name)
  558. if istype(attr, ConstantVariable):
  559. assert attr.value in (True, False, None), attr
  560. else:
  561. codegen.dup_top()
  562. codegen(attr)
  563. codegen.extend_output(codegen.rot_n(2))
  564. codegen.store_attr(name)
  565. codegen_attr("__context__")
  566. codegen_attr("__cause__")
  567. codegen_attr("__suppress_context__")
  568. def python_type(self) -> type:
  569. return self.exc_type
  570. def call_setattr(
  571. self,
  572. tx: "InstructionTranslator",
  573. name_var: VariableTracker,
  574. val: VariableTracker,
  575. ) -> VariableTracker:
  576. def raise_error(msg: str) -> NoReturn:
  577. raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)])
  578. name = name_var.as_python_constant()
  579. if name == "__context__":
  580. # Constant can be either an Exceptior or None
  581. assert isinstance(val, (ExceptionVariable, ConstantVariable))
  582. self.set_context(val)
  583. elif name == "__cause__":
  584. if val.is_constant_none() or isinstance(
  585. val,
  586. (
  587. variables.BuiltinVariable,
  588. variables.ExceptionVariable,
  589. variables.UserDefinedExceptionClassVariable,
  590. variables.UserDefinedExceptionObjectVariable,
  591. ),
  592. ):
  593. self.__cause__ = val
  594. self.__suppress_context__ = variables.ConstantVariable(True)
  595. else:
  596. raise_error("exception cause must be None or derive from BaseException")
  597. elif name == "__suppress_context__":
  598. if val.is_constant_match(True, False):
  599. self.__suppress_context__ = val
  600. else:
  601. raise_error("exception cause must be None or derive from BaseException")
  602. elif name == "__traceback__":
  603. if not TracebackVariable.is_valid_traceback(val):
  604. raise_observed_exception(
  605. TypeError,
  606. tx,
  607. args=[
  608. ConstantVariable.create(
  609. "__traceback__ must be a traceback object or None"
  610. )
  611. ],
  612. )
  613. self.__traceback__ = val
  614. else:
  615. unimplemented(
  616. gb_type="Unsupported attribute assignment on Exception object",
  617. context=f"call_setattr {self} {name}",
  618. explanation="Dynamo does not support setting the attribute "
  619. f"'{name}' on tracked exception objects. Only `__context__`, "
  620. "`__cause__`, `__suppress_context__`, and `__traceback__` are supported.",
  621. hints=[*graph_break_hints.SUPPORTABLE],
  622. )
  623. return variables.CONSTANT_VARIABLE_NONE
  624. def call_method(
  625. self,
  626. tx: "InstructionTranslator",
  627. name: str,
  628. args: list[VariableTracker],
  629. kwargs: dict[str, VariableTracker],
  630. ) -> VariableTracker:
  631. if name == "__setattr__":
  632. return self.call_setattr(tx, *args)
  633. elif name == "with_traceback":
  634. [tb] = args
  635. self.call_setattr(tx, ConstantVariable("__traceback__"), tb)
  636. return self
  637. else:
  638. return super().call_method(tx, name, args, kwargs)
  639. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  640. if name == "__context__":
  641. return self.__context__
  642. elif name == "__cause__":
  643. return self.__cause__
  644. elif name == "__suppress_context__":
  645. return self.__suppress_context__
  646. elif name == "__traceback__":
  647. return self.__traceback__
  648. elif name == "args":
  649. return variables.ListVariable(list(self.args), source=self.source)
  650. return super().var_getattr(tx, name)
  651. def __str__(self) -> str:
  652. return f"{self.__class__.__name__}({self.exc_type})"
  653. __repr__ = __str__
  654. class UnknownVariable(VariableTracker):
  655. """
  656. It could be anything!
  657. """
  658. class DelayGraphBreakVariable(UnknownVariable):
  659. """
  660. Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION.
  661. """
  662. def __init__(self, msg: str | None = None, **kwargs: Any) -> None:
  663. super().__init__(**kwargs)
  664. self.msg = msg
  665. def call_function(
  666. self,
  667. tx: "InstructionTranslator",
  668. args: Sequence[VariableTracker],
  669. kwargs: dict[str, VariableTracker],
  670. ) -> VariableTracker:
  671. name = "" if self.source is None else self.source.name
  672. unimplemented(
  673. gb_type="Unsupported function call (delayed)",
  674. context=f"source: {self.source}",
  675. explanation="Dynamo determined that a graph break should occur "
  676. f"when calling `{name}`. Reason: {self.msg}",
  677. hints=[],
  678. )
  679. class ComptimeVariable(VariableTracker):
  680. """
  681. This variable is special, it lets you execute arbitrary code at
  682. Dynamo compile time
  683. """
  684. def reconstruct(self, codegen: "PyCodegen") -> None:
  685. raise NotImplementedError("comptime is special form")
  686. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  687. from ..comptime import comptime
  688. assert self.source is not None
  689. # To support the comptime.print_graph convenience accessors
  690. return VariableTracker.build(
  691. tx, getattr(comptime, name), source=AttrSource(self.source, name)
  692. )
  693. def call_function(
  694. self,
  695. tx: "InstructionTranslator",
  696. args: Sequence[VariableTracker],
  697. kwargs: dict[str, VariableTracker],
  698. ) -> VariableTracker:
  699. from ..comptime import ComptimeContext
  700. # TODO: support an expression form as well
  701. # Second argument is runtime lambda, ignored
  702. if kwargs or len(args) > 2:
  703. raise_args_mismatch(
  704. tx,
  705. "comptime()",
  706. "at most 2 args and 0 kwargs",
  707. f"{len(args)} args and {len(kwargs)} kwargs",
  708. )
  709. fn = args[0]
  710. if isinstance(fn, UserFunctionVariable):
  711. fn.get_function()(ComptimeContext(tx))
  712. elif isinstance(fn, NestedUserFunctionVariable):
  713. # We have to manually bind the freevars ourselves
  714. code = fn.get_code()
  715. if fn.closure:
  716. raise_type_error_exc(
  717. tx,
  718. f"comptime function must not have free variables, but these variables were free: {code.co_freevars}",
  719. )
  720. func = types.FunctionType(
  721. code,
  722. fn.f_globals,
  723. fn.fn_name.as_python_constant(),
  724. # type: ignore[attr-defined]
  725. tuple(fn.defaults.items) if fn.defaults else None,
  726. # We could automatically promote free variables into
  727. # ComptimeVar but this is confusing if you access
  728. # a free variable that we actually DO have the runtime
  729. # value for
  730. # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
  731. (),
  732. )
  733. func(ComptimeContext(tx))
  734. else:
  735. raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")
  736. return variables.CONSTANT_VARIABLE_NONE
  737. class CellVariable(VariableTracker):
  738. # If the cell existed before Dynamo tracing started, this will be the
  739. # VariableTracker that represents the cell content.
  740. #
  741. # Note that all mutation to the cell (i.e., its content) will be buffered in
  742. # SideEffects, rather than being reflected here. One can think of
  743. # `CellVariable` as a special case for `UserDefinedObjectVariable`.
  744. pre_existing_contents: VariableTracker | None
  745. # This is set when this cell can be referenced via `LOAD/STORE_DEREF` in the
  746. # root frame via this name (e.g., the name is in `co_cellvars/co_freevars`).
  747. local_name: str | None = None
  748. def __init__(
  749. self, pre_existing_contents: VariableTracker | None = None, **kwargs: Any
  750. ) -> None:
  751. super().__init__(**kwargs)
  752. self.pre_existing_contents = pre_existing_contents
  753. class NewGlobalVariable(VariableTracker):
  754. def __init__(self, **kwargs: Any) -> None:
  755. super().__init__(**kwargs)
  756. def produce_trampoline_autograd_apply(fn_cls: Any) -> Callable[..., Any]:
  757. def trampoline_autograd_apply(*args: Any, **kwargs: Any) -> Any:
  758. return fn_cls.apply(*args, **kwargs)
  759. # type: ignore[attr-defined]
  760. trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
  761. return trampoline_autograd_apply
  762. class AutogradFunctionVariable(VariableTracker):
  763. """represents a torch.autograd.Function subclass"""
  764. _nonvar_fields = {
  765. "fn_cls",
  766. *VariableTracker._nonvar_fields,
  767. }
  768. def __init__(self, fn_cls: Any, **kwargs: Any) -> None:
  769. super().__init__(**kwargs)
  770. self.fn_cls = fn_cls
  771. def call_apply(
  772. self,
  773. tx: "InstructionTranslator",
  774. args: list[VariableTracker],
  775. kwargs: dict[str, VariableTracker],
  776. ) -> VariableTracker:
  777. requires_grad = False
  778. def visit(vt: VariableTracker) -> None:
  779. nonlocal requires_grad
  780. if vt.is_tensor():
  781. # type: ignore[attr-defined]
  782. if vt.requires_grad is not False:
  783. requires_grad = True
  784. if isinstance(vt, variables.NNModuleVariable):
  785. if vt.is_training(tx):
  786. requires_grad = True
  787. VariableTracker.visit(visit, (args, kwargs))
  788. if requires_grad and torch.is_grad_enabled():
  789. source = self.source
  790. from torch._functorch.autograd_function import (
  791. autograd_function_forward_rewritten,
  792. )
  793. from torch.autograd.function import _is_setup_context_defined
  794. forward_fn = self.fn_cls.forward
  795. is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context)
  796. if is_setup_ctx_defined:
  797. # If setup_context is defined, we generate a new forward function which includes
  798. # the original forward and setup_context function, and trace the new forward function.
  799. forward_fn = autograd_function_forward_rewritten(
  800. self.fn_cls.forward, self.fn_cls.setup_context
  801. )
  802. # The forward points to a new function now, so we can't use the
  803. # old source. Later on, we guard specifically on
  804. # is_setup_ctx_defined
  805. source = None
  806. vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined]
  807. if vjp_fn is not torch.autograd.Function.vjp:
  808. unimplemented(
  809. gb_type="Unsupported custom vjp",
  810. context=f"call_apply {self} {args} {kwargs}",
  811. explanation="Dynamo does not support tracing "
  812. "`torch.autograd.Function` subclasses that define "
  813. "a custom `vjp` method.",
  814. hints=[
  815. "Remove the custom `vjp` method if possible.",
  816. "Use standard `backward` instead if applicable.",
  817. *graph_break_hints.SUPPORTABLE,
  818. ],
  819. )
  820. jvp_fn = self.fn_cls.jvp # type: ignore[attr-defined]
  821. if jvp_fn is not torch.autograd.Function.jvp:
  822. unimplemented(
  823. gb_type="Unsupported custom jvp",
  824. context=f"call_apply {self} {args} {kwargs}",
  825. explanation="Dynamo does not support tracing "
  826. "`torch.autograd.Function` subclasses that define "
  827. "a custom `jvp` method.",
  828. hints=[
  829. "Remove the custom `jvp` method if possible.",
  830. *graph_break_hints.SUPPORTABLE,
  831. ],
  832. )
  833. from .higher_order_ops import AutogradFunctionApplyVariable
  834. if source is None and not is_setup_ctx_defined:
  835. source = AttrSource(
  836. tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
  837. )
  838. apply_source = source and AttrSource(source, member="apply")
  839. val = AutogradFunctionApplyVariable(
  840. forward_fn,
  841. self.fn_cls.backward,
  842. source,
  843. source=apply_source,
  844. ).call_function(tx, args, kwargs)
  845. if self.source and is_setup_ctx_defined:
  846. fwd_src = AttrSource(self.source, "forward")
  847. install_guard(fwd_src.make_guard(GuardBuilder.CLOSURE_MATCH))
  848. setup_ctx_src = AttrSource(self.source, "setup_context")
  849. install_guard(setup_ctx_src.make_guard(GuardBuilder.CLOSURE_MATCH))
  850. return val
  851. if self.source:
  852. source = AttrSource(self.source, "forward")
  853. else:
  854. source = None
  855. fn = self.fn_cls.forward
  856. ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
  857. args = [ctx, *args]
  858. if isinstance(fn, types.FunctionType):
  859. sig = inspect.signature(fn)
  860. if len(args) - 1 == len(sig.parameters):
  861. args = args[1:] # Don't use context
  862. fn_vt = VariableTracker.build(tx, fn, source=source, realize=True)
  863. return fn_vt.call_function(tx, args, kwargs)
  864. elif isinstance(fn, types.MethodType):
  865. return variables.UserMethodVariable(
  866. fn.__func__,
  867. variables.UserDefinedClassVariable(self.fn_cls),
  868. source=source,
  869. ).call_function(tx, args, kwargs)
  870. else:
  871. unimplemented(
  872. gb_type="Non-function or method in subclass of torch.autograd.Function",
  873. context=f"call_apply {self} {args} {kwargs}",
  874. explanation="Dynamo requires the `forward` attribute of a "
  875. "`torch.autograd.Function` subclass to be a standard Python "
  876. f"function or method. Found type `{type(fn).__name__}` instead.",
  877. hints=[
  878. "Ensure the `forward` method is defined as a regular "
  879. "function or instance method."
  880. ],
  881. )
  882. def call_backward(
  883. self,
  884. tx: "InstructionTranslator",
  885. args: list[VariableTracker],
  886. kwargs: dict[str, VariableTracker],
  887. ) -> VariableTracker:
  888. fn = self.fn_cls.backward
  889. # type: ignore[attr-defined]
  890. assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction
  891. assert isinstance(fn, types.FunctionType)
  892. assert self.source is not None
  893. fn_source = AttrSource(self.source, "backward")
  894. fn_vt = VariableTracker.build(tx, fn, source=fn_source, realize=True)
  895. return fn_vt.call_function(tx, args, kwargs)
  896. def call_function(
  897. self,
  898. tx: "InstructionTranslator",
  899. args: Sequence[VariableTracker],
  900. kwargs: dict[str, VariableTracker],
  901. ) -> "AutogradFunctionVariable":
  902. return AutogradFunctionVariable(self.fn_cls)
  903. def call_method(
  904. self,
  905. tx: "InstructionTranslator",
  906. name: str,
  907. args: list[VariableTracker],
  908. kwargs: dict[str, VariableTracker],
  909. ) -> VariableTracker:
  910. from .builder import wrap_fx_proxy
  911. if name == "apply":
  912. if trace_rules.is_callable_allowed(self.fn_cls):
  913. trampoline_autograd_apply = produce_trampoline_autograd_apply(
  914. self.fn_cls
  915. )
  916. return wrap_fx_proxy(
  917. tx=tx,
  918. proxy=tx.output.create_proxy(
  919. "call_function",
  920. trampoline_autograd_apply,
  921. *proxy_args_kwargs(args, kwargs),
  922. ),
  923. )
  924. else:
  925. return self.call_apply(tx, args, kwargs)
  926. elif name == "backward":
  927. return self.call_backward(tx, args, kwargs)
  928. else:
  929. source = AttrSource(self.source, name) if self.source is not None else None
  930. try:
  931. obj = inspect.getattr_static(self.fn_cls, name)
  932. except AttributeError:
  933. obj = None
  934. if isinstance(obj, staticmethod):
  935. func = obj.__get__(self.fn_cls)
  936. traced = trace_rules.lookup(func)
  937. assert traced is not None
  938. if source is not None:
  939. return (
  940. # type: ignore[attr-defined]
  941. traced.create_with_source(func, source=source).call_function(
  942. tx, args, kwargs
  943. )
  944. )
  945. else:
  946. # type: ignore[misc]
  947. return traced(func).call_function(tx, args, kwargs)
  948. elif isinstance(obj, classmethod):
  949. return variables.UserMethodVariable(
  950. obj.__func__, self, source=source
  951. ).call_function(tx, args, kwargs)
  952. else:
  953. unimplemented(
  954. gb_type="Unsupported autograd.Function method",
  955. context=f"call_method {self} {name}",
  956. explanation="Dynamo does not support calling the method "
  957. f"`{name}` directly on the `torch.autograd.Function` "
  958. "instance. Supported methods include `apply`, `backward`, "
  959. "static methods, and class methods.",
  960. hints=[
  961. "Ensure the method is decorated with `@staticmethod` "
  962. "or `@classmethod` if it's meant to be called on the class.",
  963. ],
  964. )
  965. @dataclasses.dataclass
  966. class SavedTensorBox:
  967. tensors: list[VariableTracker] = dataclasses.field(default_factory=list)
  968. class AutogradFunctionContextVariable(UserDefinedObjectVariable):
  969. """
  970. Tracks an autograd.Function() context using mutation tracking in side_effects.py
  971. """
  972. _nonvar_fields = {
  973. "proxy",
  974. "inference",
  975. "saved_tensors",
  976. *UserDefinedObjectVariable._nonvar_fields,
  977. }
  978. def __init__(
  979. self,
  980. value: Any,
  981. value_type: type | None = None,
  982. inference: bool = False,
  983. saved_tensors: Any | None = None,
  984. needs_input_grad: tuple[bool, ...] | None = None,
  985. non_differentiable: Any | None = None,
  986. **kwargs: Any,
  987. ) -> None:
  988. super().__init__(value=value, value_type=value_type, **kwargs)
  989. self.inference = inference
  990. self.saved_tensors = saved_tensors
  991. self.needs_input_grad = needs_input_grad
  992. self.non_differentiable = non_differentiable
  993. @staticmethod
  994. def create(
  995. tx: "InstructionTranslator",
  996. args: Sequence[VariableTracker] | None = None,
  997. kwargs: dict[str, VariableTracker] | None = None,
  998. ) -> VariableTracker:
  999. needs_input_grad = None
  1000. if args and not kwargs:
  1001. # type: ignore[attr-defined]
  1002. needs_input_grad = tuple(x.is_tensor() and x.requires_grad for x in args)
  1003. out = tx.output.side_effects.track_object_new(
  1004. None,
  1005. torch.autograd.function.FunctionCtx,
  1006. functools.partial(
  1007. AutogradFunctionContextVariable,
  1008. inference=True,
  1009. saved_tensors=SavedTensorBox(),
  1010. needs_input_grad=needs_input_grad,
  1011. ),
  1012. {},
  1013. )
  1014. return out
  1015. def as_proxy(self) -> Any:
  1016. # type: ignore[attr-defined]
  1017. if self.proxy is None:
  1018. unimplemented(
  1019. gb_type="proxy not set",
  1020. context=f"as_proxy {self}",
  1021. explanation="Dynamo requires the autograd.Function context "
  1022. "to be initialized with a proxy.",
  1023. hints=[*graph_break_hints.DYNAMO_BUG],
  1024. )
  1025. # type: ignore[attr-defined]
  1026. return self.proxy
  1027. def call_method(
  1028. self,
  1029. tx: "InstructionTranslator",
  1030. name: str,
  1031. args: list[VariableTracker],
  1032. kwargs: dict[str, VariableTracker],
  1033. ) -> VariableTracker:
  1034. if name == "__setattr__":
  1035. return super().call_method(tx, name, args, kwargs)
  1036. elif name == "mark_non_differentiable":
  1037. if kwargs:
  1038. raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs")
  1039. self.non_differentiable = proxy_args_kwargs(args, {})[0]
  1040. return variables.CONSTANT_VARIABLE_NONE
  1041. if name != "save_for_backward":
  1042. unimplemented(
  1043. gb_type="Unsupported autograd.Function context method",
  1044. context=f"call_method {self} {name}",
  1045. explanation="Dynamo does not support calling the method "
  1046. f"`{name}` on `autograd.Function` context objects. Supported "
  1047. "methods are `__setattr__`, `save_for_backward` and "
  1048. "`mark_non_differentiable`.",
  1049. hints=[*graph_break_hints.SUPPORTABLE],
  1050. )
  1051. if self.saved_tensors is None:
  1052. unimplemented(
  1053. gb_type="Unsupported autograd.Function context `save_for_backward`",
  1054. context=f"call_method {self} {name}",
  1055. explanation="Dynamo requires the `saved_tensors` attribute "
  1056. "to be initialized on the `autograd.Function` context object.",
  1057. hints=[
  1058. "Ensure that the `saved_tensors` attribute is properly "
  1059. "initialized before calling `save_for_backward`. "
  1060. "`save_for_backward` only supported on a newly constructed `torch.autograd.function.FunctionCtx`.",
  1061. ],
  1062. )
  1063. assert self.saved_tensors is not None
  1064. if not self.inference:
  1065. if kwargs or not self.source:
  1066. raise_type_error_exc(
  1067. tx, "save_for_backward() requires a source and no keyword arguments"
  1068. )
  1069. tx.output.side_effects.track_save_for_backward(self, args)
  1070. # In eager mode, multiple calls to .save_for_backward() will overwrite previous calls.
  1071. if len(self.saved_tensors.tensors) > 0:
  1072. # pyrefly: ignore [implicit-any]
  1073. self.saved_tensors.tensors = []
  1074. for arg in args:
  1075. self.saved_tensors.tensors.append(arg)
  1076. return variables.CONSTANT_VARIABLE_NONE
  1077. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  1078. if name in ["save_for_backward", "mark_non_differentiable"]:
  1079. return LambdaVariable(
  1080. lambda *args, **kwargs: self.call_method(tx, name, list(args), kwargs)
  1081. )
  1082. if name == "saved_tensors" and self.saved_tensors is not None:
  1083. return variables.TupleVariable(list(self.saved_tensors.tensors))
  1084. if name == "needs_input_grad":
  1085. if self.needs_input_grad is not None:
  1086. return variables.ConstantVariable.create(self.needs_input_grad)
  1087. if self.source:
  1088. source = AttrSource(self.source, "needs_input_grad")
  1089. # type: ignore[attr-defined]
  1090. return VariableTracker.build(tx, self.value.needs_input_grad, source)
  1091. return super().var_getattr(tx, name)
  1092. class AutogradEngineVariable(UserDefinedObjectVariable):
  1093. """
  1094. Represents a torch._C._ImperativeEngine instance.
  1095. """
  1096. def __init__(
  1097. self,
  1098. value: torch._C._ImperativeEngine,
  1099. value_type: type[torch._C._ImperativeEngine] | None = None,
  1100. **kwargs: Any,
  1101. ) -> None:
  1102. super().__init__(value=value, value_type=value_type, **kwargs)
  1103. def call_method(
  1104. self,
  1105. tx: "InstructionTranslator",
  1106. name: str,
  1107. args: list[VariableTracker],
  1108. kwargs: dict[str, VariableTracker],
  1109. ) -> VariableTracker:
  1110. if name == "queue_callback":
  1111. if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
  1112. assert tx.one_graph or tx.error_on_graph_break, (
  1113. "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
  1114. )
  1115. # queue_callback is a method-wrapper, no need to insert a guard.
  1116. fn_vt = VariableTracker.build(
  1117. tx,
  1118. torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
  1119. )
  1120. return fn_vt.call_function(
  1121. tx,
  1122. (tx.output.side_effects.get_ca_final_callbacks_var(), *args),
  1123. kwargs,
  1124. )
  1125. else:
  1126. unimplemented(
  1127. gb_type="Unsupported torch._C._ImperativeEngine.queue_callback()",
  1128. context=f"call_method {self} {name}",
  1129. explanation="queue_callback() is only supported when "
  1130. "Compiled Autograd is enabled with fullgraph=True.",
  1131. hints=[],
  1132. )
  1133. else:
  1134. unimplemented(
  1135. gb_type="Unsupported torch._C._ImperativeEngine method",
  1136. context=f"call_method {self} {name}",
  1137. explanation="Dynamo only supports the `queue_callback` method "
  1138. f"on a torch._C._ImperativeEngine instance, but found: `{name}`.",
  1139. hints=[],
  1140. )
  1141. class LambdaVariable(VariableTracker):
  1142. # TODO: change to Ts = TypeVarTuple("Ts") for py 3.11+
  1143. def __init__(self, fn: Callable[..., VariableTracker], **kwargs: Any) -> None:
  1144. super().__init__(**kwargs)
  1145. self.fn = fn
  1146. def call_function(
  1147. self,
  1148. tx: "InstructionTranslator",
  1149. args: Sequence[VariableTracker],
  1150. kwargs: dict[str, VariableTracker],
  1151. ) -> VariableTracker:
  1152. return self.fn(*args, **kwargs)
  1153. class GetAttrVariable(VariableTracker):
  1154. _nonvar_fields = {
  1155. "name",
  1156. "py_type",
  1157. *VariableTracker._nonvar_fields,
  1158. }
  1159. def __init__(
  1160. self,
  1161. obj: VariableTracker,
  1162. name: str,
  1163. py_type: type | None = None,
  1164. **kwargs: Any,
  1165. ) -> None:
  1166. super().__init__(**kwargs)
  1167. assert isinstance(obj, VariableTracker)
  1168. assert isinstance(name, str)
  1169. self.obj = obj
  1170. self.name = name
  1171. self.py_type = py_type # In some cases we know the type (ex. tensor methods)
  1172. def python_type(self) -> type:
  1173. if self.py_type is not None:
  1174. return self.py_type
  1175. else:
  1176. return super().python_type()
  1177. def __repr__(self) -> str:
  1178. return f"{self.__class__.__name__}({self.obj}, {self.name})"
  1179. @staticmethod
  1180. def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr: str) -> Any:
  1181. return getattr(base_proxy, attr)
  1182. def as_proxy(self) -> Any:
  1183. return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)
  1184. def as_python_constant(self) -> Any:
  1185. constant = self.obj.as_python_constant()
  1186. try:
  1187. return getattr(constant, self.name)
  1188. except AttributeError:
  1189. raise NotImplementedError(f"{self} is not a constant") from None
  1190. def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
  1191. if not isinstance(self.obj, variables.NNModuleVariable):
  1192. raise NotImplementedError
  1193. step1 = tx.output.get_submodule(self.obj.module_key)
  1194. if self.name not in step1.__dict__:
  1195. raise NotImplementedError
  1196. step2 = inspect.getattr_static(step1, self.name)
  1197. if name not in step2.__dict__:
  1198. raise NotImplementedError
  1199. return inspect.getattr_static(step2, name)
  1200. def reconstruct(self, codegen: "PyCodegen") -> None:
  1201. codegen(self.obj)
  1202. codegen.extend_output(codegen.create_load_attrs(self.name))
  1203. def call_function(
  1204. self,
  1205. tx: "InstructionTranslator",
  1206. args: Sequence[VariableTracker],
  1207. kwargs: dict[str, VariableTracker],
  1208. ) -> VariableTracker:
  1209. return self.obj.call_method(tx, self.name, list(args), kwargs)
  1210. def call_method(
  1211. self,
  1212. tx: "InstructionTranslator",
  1213. name: str,
  1214. args: list[VariableTracker],
  1215. kwargs: dict[str, VariableTracker],
  1216. ) -> VariableTracker:
  1217. if (
  1218. name in ("__getitem__", "get")
  1219. and self.name == "__dict__"
  1220. and not kwargs
  1221. and args[0].is_python_constant()
  1222. and isinstance(
  1223. self.obj,
  1224. (
  1225. variables.UserDefinedObjectVariable,
  1226. variables.NNModuleVariable,
  1227. variables.UserDefinedClassVariable,
  1228. ),
  1229. )
  1230. ):
  1231. obj = self.obj
  1232. key = args[0].as_python_constant()
  1233. if obj.has_key_in_generic_dict(tx, key):
  1234. # redirect to var_getattr on the original obj
  1235. return obj.var_getattr(tx, key)
  1236. # Return the default value for get
  1237. if name == "get":
  1238. if len(args) == 2:
  1239. return args[1]
  1240. else:
  1241. return variables.CONSTANT_VARIABLE_NONE
  1242. elif (
  1243. name == "__contains__"
  1244. and self.name == "__dict__"
  1245. and len(args) == 1
  1246. and args[0].is_python_constant()
  1247. and not kwargs
  1248. and isinstance(
  1249. self.obj,
  1250. (
  1251. variables.UserDefinedObjectVariable,
  1252. variables.NNModuleVariable,
  1253. variables.UserDefinedClassVariable,
  1254. ),
  1255. )
  1256. ):
  1257. obj = self.obj
  1258. key = args[0].as_python_constant()
  1259. if obj.has_key_in_generic_dict(tx, key):
  1260. return variables.ConstantVariable(True)
  1261. else:
  1262. return variables.ConstantVariable(False)
  1263. elif name == "__setitem__" and self.name == "__dict__" and not kwargs:
  1264. if isinstance(self.obj, variables.UserDefinedObjectVariable):
  1265. # Bypass any custom setattr as we are updating the `__dict__` itself
  1266. return self.obj.method_setattr_standard(
  1267. tx, args[0], args[1], directly_update_dict=True
  1268. )
  1269. if isinstance(self.obj, variables.NNModuleVariable):
  1270. # This matches how `setattr` is handled for NNModuleVariable
  1271. self.obj.convert_to_unspecialized(tx)
  1272. return super().call_method(tx, name, args, kwargs)
  1273. def get_forwarded_dict(self, tx: "InstructionTranslator") -> VariableTracker:
  1274. assert (
  1275. self.name == "__dict__"
  1276. and isinstance(self.obj, variables.UserDefinedClassVariable)
  1277. and not tx.output.side_effects.has_pending_mutation(self.obj)
  1278. )
  1279. self.obj.ban_mutation = True
  1280. return VariableTracker.build(tx, self.obj.value.__dict__, self.source)
  1281. class MethodWrapperVariable(VariableTracker):
  1282. def __init__(self, method_wrapper: types.MethodWrapperType, **kwargs: Any) -> None:
  1283. super().__init__(**kwargs)
  1284. self.method_wrapper = method_wrapper
  1285. def call_function(
  1286. self,
  1287. tx: "InstructionTranslator",
  1288. args: Sequence[VariableTracker],
  1289. kwargs: dict[str, VariableTracker],
  1290. ) -> VariableTracker:
  1291. if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
  1292. args[0], variables.TensorVariable
  1293. ):
  1294. if not (len(args) == 1 and len(kwargs) == 0):
  1295. raise_type_error_exc(
  1296. tx, "tensor attribute getter takes exactly one argument"
  1297. )
  1298. # type: ignore[arg-type, attr-defined]
  1299. return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)
  1300. # method-wrapper variables are common in __init__ calls. For example,
  1301. # str("foo").__init__ is a method-wrapper. These method wrappers point
  1302. # to C functions. Here we intercept if these method-wrappers are from
  1303. # builtins and then call the function counterpart directly by obtaining
  1304. # the self object.
  1305. self_obj = self.method_wrapper.__self__
  1306. wrapper_name = self.method_wrapper.__name__
  1307. # TODO(dynamo-team) - We can perhaps expand the scope to more names and
  1308. # more builtins.
  1309. if wrapper_name == "__init__":
  1310. fn_obj = type(self_obj).__init__
  1311. if fn_obj is object.__init__:
  1312. return variables.BuiltinVariable(object).call_method(
  1313. tx,
  1314. wrapper_name,
  1315. # type: ignore[arg-type, list-item]
  1316. [self_obj, *args],
  1317. kwargs,
  1318. )
  1319. elif (
  1320. sys.version_info >= (3, 14)
  1321. # for some reason, even if the below check passes,
  1322. # self.method_wrapper may not be the same as type.__dict__["__annotations__"].__get__
  1323. and self_obj is type.__dict__["__annotations__"]
  1324. and wrapper_name == "__get__"
  1325. ):
  1326. from .builder import SourcelessBuilder
  1327. if len(args) == 1 and not kwargs:
  1328. try:
  1329. return SourcelessBuilder.create(
  1330. tx, self.method_wrapper(args[0].as_python_constant())
  1331. )
  1332. except AttributeError:
  1333. raise_observed_exception(AttributeError, tx)
  1334. except AsPythonConstantNotImplementedError:
  1335. pass
  1336. unimplemented(
  1337. gb_type="unsupported type.__dict__['__annotations__'].__get__ call",
  1338. context=f"call_function {self}, args: {args}, kwargs: {kwargs}",
  1339. explanation="`torch.compile` only supports calling type.__dict__['__annotations__'].__get__ "
  1340. "on a single constant argument (i.e. a type).",
  1341. hints=[
  1342. "Make sure your call to type.__dict__['__annotations__'] only has "
  1343. "one positional argument (no keyword arguments).",
  1344. "Make sure the argument to type.__dict__['__annotations__'] is a constant "
  1345. "(i.e. type). For example, `object`, `int`, `MyCustomClass`.",
  1346. *graph_break_hints.SUPPORTABLE,
  1347. ],
  1348. )
  1349. elif (self_obj is type.__dict__["__mro__"] and wrapper_name == "__get__") or (
  1350. self_obj is type.__dict__["__dict__"] and wrapper_name == "__get__"
  1351. ):
  1352. from .builder import SourcelessBuilder
  1353. if len(args) == 1 and not kwargs:
  1354. try:
  1355. return SourcelessBuilder.create(
  1356. tx, self.method_wrapper(args[0].as_python_constant())
  1357. )
  1358. except AsPythonConstantNotImplementedError:
  1359. pass
  1360. attr_name = (
  1361. "__mro__" if self_obj is type.__dict__["__mro__"] else "__dict__"
  1362. )
  1363. unimplemented(
  1364. gb_type=f"unsupported type.__dict__['{attr_name}'].__get__ call",
  1365. context=f"call_function {self}, args: {args}, kwargs: {kwargs}",
  1366. explanation=f"`torch.compile` only supports calling type.__dict__['{attr_name}'].__get__ "
  1367. "on a single constant argument (i.e. a type).",
  1368. hints=[
  1369. f"Make sure your call to type.__dict__['{attr_name}'].__get__ only has "
  1370. "one positional argument (no keyword arguments).",
  1371. f"Make sure the argument to type.__dict__['{attr_name}'].__get__ is a constant "
  1372. "(i.e. type). For example, `object`, `int`, `MyCustomClass`.",
  1373. *graph_break_hints.SUPPORTABLE,
  1374. ],
  1375. )
  1376. return super().call_function(tx, args, kwargs)
  1377. def is_python_constant(self) -> Literal[True]:
  1378. return True
  1379. def as_python_constant(self) -> types.MethodWrapperType:
  1380. return self.method_wrapper
  1381. def is_python_hashable(self) -> Literal[True]:
  1382. return True
  1383. def get_python_hash(self) -> int:
  1384. return hash(self.as_python_constant())
  1385. def is_python_equal(self, other: object) -> bool:
  1386. return (
  1387. isinstance(other, VariableTracker)
  1388. and self.as_python_constant() == other.as_python_constant()
  1389. )
  1390. class GetSetDescriptorVariable(VariableTracker):
  1391. def __init__(self, desc: types.GetSetDescriptorType, **kwargs: Any) -> None:
  1392. super().__init__(**kwargs)
  1393. self.desc = desc
  1394. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  1395. if name == "__get__" and self.source:
  1396. source = AttrSource(self.source, "__get__")
  1397. return VariableTracker.build(tx, self.desc.__get__, source)
  1398. else:
  1399. return super().var_getattr(tx, name)
  1400. def is_python_constant(self) -> Literal[True]:
  1401. return True
  1402. def as_python_constant(self) -> types.GetSetDescriptorType:
  1403. return self.desc
  1404. class PythonModuleVariable(VariableTracker):
  1405. _nonvar_fields = {
  1406. "value",
  1407. "is_torch",
  1408. *VariableTracker._nonvar_fields,
  1409. }
  1410. def __init__(self, value: types.ModuleType, **kwargs: Any) -> None:
  1411. super().__init__(**kwargs)
  1412. self.value = value
  1413. self.is_torch = self.value is torch or self.value.__name__.startswith("torch.")
  1414. def python_type(self) -> type[types.ModuleType]:
  1415. return types.ModuleType
  1416. def as_python_constant(self) -> types.ModuleType:
  1417. return self.value
  1418. def __repr__(self) -> str:
  1419. return f"PythonModuleVariable({self.value})"
  1420. def call_obj_hasattr(
  1421. self, tx: "InstructionTranslator", name: str
  1422. ) -> ConstantVariable:
  1423. result = hasattr(self.value, name)
  1424. return variables.ConstantVariable.create(result)
  1425. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  1426. if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
  1427. return tx.output.side_effects.load_attr(self, name)
  1428. attr_value = None
  1429. if self.is_torch or name not in self.value.__dict__:
  1430. try:
  1431. attr_value = getattr(self.value, name)
  1432. except AttributeError:
  1433. raise_observed_exception(AttributeError, tx)
  1434. else:
  1435. attr_value = self.value.__dict__[name]
  1436. source = self.source and AttrSource(self.source, name)
  1437. return VariableTracker.build(tx, attr_value, source)
  1438. class TypingVariable(VariableTracker):
  1439. def __init__(self, value: Any, **kwargs: Any) -> None:
  1440. super().__init__(**kwargs)
  1441. self.value = value
  1442. def call_method(
  1443. self,
  1444. tx: "InstructionTranslator",
  1445. name: str,
  1446. args: list[VariableTracker],
  1447. kwargs: dict[str, VariableTracker],
  1448. ) -> VariableTracker:
  1449. # Create a new typing variable, e.g., `List[int]`
  1450. if name == "__getitem__" and len(args) == 1:
  1451. new_typing = self.value[args[0].as_python_constant()]
  1452. return TypingVariable(new_typing)
  1453. elif name == "__eq__":
  1454. if len(args) == 1 and not kwargs:
  1455. result = istype(args[0], TypingVariable) and self.value == args[0].value
  1456. return variables.ConstantVariable.create(result)
  1457. unimplemented(
  1458. gb_type="unsupported method call on `typing` variable",
  1459. context=f"typing variable: {self.value}, method name: {name}, args: {args}, kwargs: {kwargs}",
  1460. explanation=f"`torch.compile` does not support method call `{name}` on `typing` variable f{self.value}.",
  1461. hints=[
  1462. f"Avoid calling the {name} method on {self.value}.",
  1463. *graph_break_hints.SUPPORTABLE,
  1464. ],
  1465. )
  1466. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  1467. from .builder import SourcelessBuilder, VariableBuilder
  1468. if name in cmp_name_to_op_mapping:
  1469. return variables.GetAttrVariable(self, name)
  1470. if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
  1471. return tx.output.side_effects.load_attr(self, name)
  1472. value = getattr(self.value, name)
  1473. if self.source:
  1474. attr_source = AttrSource(self.source, name)
  1475. return VariableBuilder(tx, attr_source)(value)
  1476. else:
  1477. return SourcelessBuilder.create(tx, value)
  1478. def as_python_constant(self) -> Any:
  1479. return self.value
  1480. def reconstruct(self, codegen: "PyCodegen") -> None:
  1481. if not isinstance(self.value, types.GenericAlias):
  1482. return super().reconstruct(codegen)
  1483. # We're just trying to load the type here. Reconstructing the type from
  1484. # scratch is tricky - for a type like `typing.List[int]` we'd need to
  1485. # deconstruct the origin and args. The origin for `List[int]` is `list`
  1486. # and the args is `(int,)`. When we recombine those we get the parts
  1487. # back and need to emit code for:
  1488. #
  1489. # `typing.List[int]`
  1490. #
  1491. # But it's # worse than that - what if `typing` isn't in the globals (or
  1492. # was loaded like `import typing as _typing ; _typing.List[int]`?) so we
  1493. # really need to do something like:
  1494. #
  1495. # `sys.modules["typing"].List[int]`
  1496. #
  1497. # Argh - but what if they rewrote the global `int`? So we have to do:
  1498. #
  1499. # `sys.modules["typing"].List[sys.modules["builtins"].int]`
  1500. #
  1501. # But where do we get `sys`? What if they never imported it or have
  1502. # something ELSE called `sys`?
  1503. #
  1504. # Let's skip all that noise and just emit it as a simple const.
  1505. #
  1506. codegen.append_output(codegen.create_load_const(self.value))
  1507. def is_python_hashable(self) -> Literal[True]:
  1508. return True
  1509. def get_python_hash(self) -> int:
  1510. return hash(self.as_python_constant())
  1511. def is_python_equal(self, other: object) -> bool:
  1512. return (
  1513. isinstance(other, VariableTracker)
  1514. and self.as_python_constant() == other.as_python_constant()
  1515. )
  1516. @functools.lru_cache(maxsize=1)
  1517. def get_np_to_tnp_map() -> dict[types.BuiltinFunctionType, types.FunctionType]:
  1518. """
  1519. This generates a mapping from numpy modules to their torch._numpy
  1520. modules equivalents.
  1521. """
  1522. from ..utils import NP_TO_TNP_MODULE
  1523. np_fn_to_tnp_fn = {}
  1524. for np_mod, tnp_mod in NP_TO_TNP_MODULE.items():
  1525. for fn_name, tnp_fn in tnp_mod.__dict__.items():
  1526. if callable(tnp_fn):
  1527. # some internal details do leak from tnp
  1528. # which are not part of numpy API.
  1529. if np_fn := getattr(np_mod, fn_name, None):
  1530. np_fn_to_tnp_fn[np_fn] = tnp_fn
  1531. return np_fn_to_tnp_fn
  1532. @functools.lru_cache(maxsize=1)
  1533. def get_tnp_to_np_map() -> dict[types.FunctionType, types.BuiltinFunctionType]:
  1534. """
  1535. This is just the reverse mapping of get_np_to_tnp_map() - mapping from
  1536. torch._numpy modules to numpy equivalents.
  1537. """
  1538. m = get_np_to_tnp_map()
  1539. return {v: k for k, v in m.items()}
  1540. class NumpyVariable(VariableTracker):
  1541. """
  1542. Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes.
  1543. """
  1544. constant_fold_functions = (tnp.issubdtype,)
  1545. def __init__(self, value: Any, **kwargs: Any) -> None:
  1546. super().__init__(**kwargs)
  1547. self.value = value
  1548. @classmethod
  1549. def can_constant_fold_through(cls, fn: types.FunctionType) -> bool:
  1550. mod = fn.__module__.split(".")
  1551. assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
  1552. return fn in cls.constant_fold_functions
  1553. @classmethod
  1554. def get_constant_collection_for_func(cls, fn: types.FunctionType) -> Any:
  1555. mod = fn.__module__.split(".")
  1556. assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
  1557. return np_constant_collections_map.get(fn)
  1558. def call_function(
  1559. self,
  1560. tx: "InstructionTranslator",
  1561. args: Sequence[VariableTracker],
  1562. kwargs: dict[str, VariableTracker],
  1563. ) -> VariableTracker:
  1564. if not config.trace_numpy:
  1565. unimplemented(
  1566. gb_type="attempted to trace numpy function with config.trace_numpy=False",
  1567. context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}",
  1568. explanation=f"Attempted to trace numpy function {self.value} "
  1569. "while `torch._dynamo.config.trace_numpy` was set to False.",
  1570. hints=[
  1571. "Set `torch._dynamo.config.trace_numpy` to True to trace numpy functions.",
  1572. ],
  1573. )
  1574. from ..utils import numpy_to_tensor_wrapper
  1575. from .tensor import NumpyNdarrayVariable
  1576. func = get_np_to_tnp_map().get(self.value)
  1577. if func is None:
  1578. unimplemented(
  1579. gb_type="attempted to trace numpy function unsupported by PyTorch",
  1580. context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
  1581. explanation=f"Can't find numpy numpy function {self.value} in torch._numpy.",
  1582. hints=[
  1583. *graph_break_hints.SUPPORTABLE,
  1584. ],
  1585. )
  1586. # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo)
  1587. assert func is not None
  1588. if (
  1589. collection_variable_typ := self.get_constant_collection_for_func(func)
  1590. ) is not None:
  1591. try:
  1592. return collection_variable_typ(
  1593. self.value(
  1594. *[x.as_python_constant() for x in args],
  1595. **{k: v.as_python_constant() for k, v in kwargs.items()},
  1596. )
  1597. )
  1598. except AsPythonConstantNotImplementedError:
  1599. unimplemented(
  1600. gb_type="numpy function that produces a const collection type encountered non-const arguments",
  1601. context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
  1602. explanation=f"numpy function {self.value} that produces a const collection type "
  1603. "(e.g. np.dtype, np.iinfo/np.finfo) "
  1604. "received arguments that are not constant.",
  1605. hints=[
  1606. *graph_break_hints.USER_ERROR,
  1607. ],
  1608. )
  1609. else:
  1610. if (
  1611. func.__module__ == "torch._numpy.random"
  1612. and config.use_numpy_random_stream
  1613. ):
  1614. unimplemented(
  1615. gb_type="attempted to trace torch._numpy.random function with config.use_numpy_random_stream=True",
  1616. context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
  1617. explanation=f"Attempted to trace {self.value} when `torch._dynamo.config.use_numpy_random_stream` "
  1618. "is set to True.",
  1619. hints=[
  1620. "Set `torch._dynamo.config.use_numpy_random_stream` to False.",
  1621. f"Avoid calling {self.value}.",
  1622. ],
  1623. )
  1624. args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs)
  1625. if self.can_constant_fold_through(func) and (
  1626. check_unspec_or_constant_args(args, kwargs)
  1627. ):
  1628. # constant fold
  1629. return variables.ConstantVariable.create(
  1630. self.as_python_constant()(
  1631. *[x.as_python_constant() for x in args],
  1632. **{k: v.as_python_constant() for k, v in kwargs.items()},
  1633. ),
  1634. )
  1635. # TODO Add all the functions that go from constants to constants to can_constant_fold_through
  1636. proxy = tx.output.create_proxy(
  1637. "call_function",
  1638. numpy_to_tensor_wrapper(func),
  1639. *proxy_args_kwargs(args, kwargs),
  1640. )
  1641. return NumpyNdarrayVariable.create(tx, proxy)
  1642. def call_method(
  1643. self,
  1644. tx: "InstructionTranslator",
  1645. name: str,
  1646. args: list[VariableTracker],
  1647. kwargs: dict[str, VariableTracker],
  1648. ) -> VariableTracker:
  1649. unimplemented(
  1650. gb_type="attempted to trace numpy.* function as a method",
  1651. context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}",
  1652. explanation="Tracing numpy.* functions as methods is not supported.",
  1653. hints=[
  1654. *graph_break_hints.DIFFICULT,
  1655. ],
  1656. )
  1657. def as_python_constant(self) -> BuiltinFunctionType:
  1658. return self.value
  1659. def as_proxy(self) -> Any:
  1660. if config.trace_numpy:
  1661. # Can replace with EnumType once we drop 3.10 support
  1662. if isinstance(self.value, enum.EnumMeta):
  1663. # This is mostly for np._CopyMode
  1664. return self.value
  1665. if isinstance(self.value, type):
  1666. # This handles numpy dtype attributes such as np.float32
  1667. # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph
  1668. # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does
  1669. return self.value.__name__
  1670. return super().as_proxy()
  1671. def is_python_hashable(self) -> Literal[True]:
  1672. return True
  1673. def get_python_hash(self) -> int:
  1674. return hash(self.as_python_constant())
  1675. def is_python_equal(self, other: object) -> bool:
  1676. return (
  1677. isinstance(other, VariableTracker)
  1678. and self.as_python_constant() == other.as_python_constant()
  1679. )
  1680. # Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
  1681. class NullVariable(VariableTracker):
  1682. def __init__(self, **kwargs: Any) -> None:
  1683. super().__init__(**kwargs)
  1684. def __repr__(self) -> str:
  1685. return "NullVariable"
  1686. def reconstruct(self, codegen: "PyCodegen") -> None:
  1687. if sys.version_info < (3, 11):
  1688. unimplemented(
  1689. gb_type="cannot reconstruct NullVariable in Python < 3.11",
  1690. context="",
  1691. explanation="Attempted to generate PUSH_NULL instruction in Python < 3.11; "
  1692. "where this instruction does not exist.",
  1693. hints=[
  1694. *graph_break_hints.DYNAMO_BUG,
  1695. ],
  1696. )
  1697. codegen.append_output(create_instruction("PUSH_NULL"))
  1698. class DeletedVariable(VariableTracker):
  1699. """Marker used to implement delattr()"""
  1700. class StringFormatVariable(VariableTracker):
  1701. """
  1702. Represents a call to str.format(), we delay calling format until after the graph.
  1703. """
  1704. _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields}
  1705. @classmethod
  1706. def create(
  1707. cls,
  1708. format_string: str,
  1709. sym_args: Sequence[VariableTracker],
  1710. sym_kwargs: dict[str, VariableTracker],
  1711. ) -> VariableTracker:
  1712. if all(
  1713. x.is_python_constant()
  1714. for x in itertools.chain(sym_args, sym_kwargs.values())
  1715. ):
  1716. return variables.ConstantVariable.create(
  1717. format_string.format(
  1718. *[v.as_python_constant() for v in sym_args],
  1719. **{k: v.as_python_constant() for k, v in sym_kwargs.items()},
  1720. )
  1721. )
  1722. return cls(format_string, list(sym_args), dict(sym_kwargs))
  1723. def __init__(
  1724. self,
  1725. format_string: str,
  1726. sym_args: Sequence[VariableTracker],
  1727. sym_kwargs: dict[str, VariableTracker],
  1728. **kwargs: Any,
  1729. ) -> None:
  1730. super().__init__(**kwargs)
  1731. assert isinstance(format_string, str)
  1732. self.format_string = format_string
  1733. self.sym_args = sym_args
  1734. self.sym_kwargs = sym_kwargs
  1735. def __repr__(self) -> str:
  1736. return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})"
  1737. def reconstruct(self, codegen: "PyCodegen") -> None:
  1738. codegen.add_push_null(
  1739. lambda: codegen.extend_output(
  1740. [
  1741. codegen.create_load_const(self.format_string),
  1742. codegen.create_load_attr("format"),
  1743. ]
  1744. ),
  1745. call_function_ex=True,
  1746. )
  1747. codegen(variables.TupleVariable(list(self.sym_args)))
  1748. kwargs = {
  1749. variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items()
  1750. }
  1751. codegen(variables.ConstDictVariable(kwargs))
  1752. codegen.extend_output(create_call_function_ex(True, False))
  1753. class ObjectVariable(VariableTracker):
  1754. # placeholder for unknown / opaque values
  1755. def __init__(self, value: object, **kwargs: Any) -> None:
  1756. super().__init__(**kwargs)
  1757. self.value = value
  1758. def python_type(self) -> type[object]:
  1759. return object
  1760. class DebuggingVariable(VariableTracker):
  1761. """
  1762. Represents a call to a debugging function like print(), or something
  1763. registered to config.reorderable_logging_functions.
  1764. """
  1765. def __init__(self, value: Any, **kwargs: Any) -> None:
  1766. super().__init__(**kwargs)
  1767. self.value = value
  1768. @staticmethod
  1769. def is_reorderable_logging_function(
  1770. obj: Any,
  1771. ) -> TypeGuard[types.FunctionType | types.BuiltinFunctionType]:
  1772. return (
  1773. callable(obj)
  1774. and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType))
  1775. and obj in torch._dynamo.config.reorderable_logging_functions
  1776. )
  1777. # type: ignore[override]
  1778. def call_function(
  1779. self,
  1780. tx: "InstructionTranslator",
  1781. args: Sequence[VariableTracker],
  1782. kwargs: dict[str, VariableTracker],
  1783. ) -> None:
  1784. if tx.export:
  1785. # For export cases, we can just make debugging functions no-ops
  1786. return
  1787. if not self.can_reorder_logs(self.value, args, kwargs):
  1788. unimplemented(
  1789. gb_type="attempted to reorder a debugging function that can't actually be reordered",
  1790. context=f"fn: {self.value}, args: {args}, kwargs: {kwargs}",
  1791. explanation="`torch.compile` can only reorder functions where the arguments "
  1792. "are Tensors, constants, or string formatters.",
  1793. hints=[
  1794. f"Avoid calling the logging function {self.value} with args that are not supported.",
  1795. ],
  1796. )
  1797. tx.debug_locals.append((self, list(args)))
  1798. def reconstruct(self, codegen: "PyCodegen") -> None:
  1799. assert self.source is not None
  1800. return self.source.reconstruct(codegen)
  1801. @staticmethod
  1802. def can_reorder_logs(fn: Any, args: Sequence[Any], kwargs: dict[str, Any]) -> bool:
  1803. """
  1804. Run some additional checks for what sort of function calls can we
  1805. actually reorder.
  1806. """
  1807. allowed_input_types = (
  1808. variables.TensorVariable,
  1809. variables.ConstantVariable,
  1810. StringFormatVariable,
  1811. )
  1812. flat_args = pytree.tree_leaves([args, kwargs])
  1813. for arg in flat_args:
  1814. if not isinstance(arg, allowed_input_types):
  1815. return False
  1816. return True
  1817. class IgnoredFunctionVariable(VariableTracker):
  1818. """
  1819. Represents a call to an arbitrary function that should be ignored.
  1820. """
  1821. def __init__(self, value: Any, **kwargs: Any) -> None:
  1822. super().__init__(**kwargs)
  1823. self.value = value
  1824. def call_function(
  1825. self,
  1826. tx: "InstructionTranslator",
  1827. args: Sequence[VariableTracker],
  1828. kwargs: dict[str, VariableTracker],
  1829. ) -> VariableTracker:
  1830. return variables.CONSTANT_VARIABLE_NONE
  1831. class LoggingLoggerVariable(VariableTracker):
  1832. """
  1833. Represents a call to any logging.Logger methods.
  1834. """
  1835. def __init__(self, value: logging.Logger, **kwargs: Any) -> None:
  1836. super().__init__(**kwargs)
  1837. self.value = value
  1838. def call_method(
  1839. self,
  1840. tx: "InstructionTranslator",
  1841. name: str,
  1842. args: list[VariableTracker],
  1843. kwargs: dict[str, VariableTracker],
  1844. ) -> VariableTracker:
  1845. if tx.export:
  1846. # For export cases, we can just make logging functions no-ops.
  1847. return variables.CONSTANT_VARIABLE_NONE
  1848. method = getattr(self.value, name, None)
  1849. function = getattr(method, "__func__", None)
  1850. # Unified ignore set
  1851. ignore_set = torch._dynamo.config.ignore_logging_functions
  1852. if method in ignore_set or function in ignore_set:
  1853. return variables.CONSTANT_VARIABLE_NONE
  1854. unimplemented(
  1855. gb_type="logging.Logger method not supported for non-export cases",
  1856. context=f"method: {self.value}.{name}, args: {args}, kwargs: {kwargs}",
  1857. explanation="logging.Logger methods are not supported for non-export cases.",
  1858. hints=[
  1859. "Add the logging method to `torch._dynamo.config.ignore_logging_functions`.",
  1860. ],
  1861. )
  1862. class ConstantLikeVariable(VariableTracker):
  1863. """self.value is a compile-time constant, but not a literal"""
  1864. try:
  1865. from numpy import (
  1866. dtype as np_dtype,
  1867. floating as np_floating,
  1868. generic as np_generic,
  1869. )
  1870. except ImportError:
  1871. # type: ignore[misc, assignment]
  1872. np_floating = type("invalid_type", (), {})
  1873. # type: ignore[misc, assignment]
  1874. np_dtype = type("invalid_type", (), {})
  1875. def __init__(self, value: Any, **kwargs: Any) -> None:
  1876. super().__init__(**kwargs)
  1877. self.value = value
  1878. @property
  1879. def _error_prefix(self) -> str:
  1880. """Dynamically compute the prefix from the value's type"""
  1881. t = type(self.value)
  1882. # For builtins (int, str, etc.), just return the name
  1883. if t.__module__ == "builtins":
  1884. return t.__qualname__
  1885. return f"{t.__module__}.{t.__qualname__}"
  1886. def as_python_constant(self) -> Any:
  1887. return self.value
  1888. def call_method(
  1889. self,
  1890. tx: "InstructionTranslator",
  1891. name: str,
  1892. args: list[VariableTracker],
  1893. kwargs: dict[str, VariableTracker],
  1894. ) -> VariableTracker:
  1895. # pyrefly: ignore [implicit-any]
  1896. cargs, ckwargs = [], {}
  1897. try:
  1898. # we only support constant propagation for methods
  1899. cargs = [x.as_python_constant() for x in args]
  1900. ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
  1901. except NotImplementedError:
  1902. unimplemented(
  1903. gb_type="constant-like method call with non-constant args",
  1904. context=f"{self._error_prefix}.{name}(*{args}, **{kwargs})",
  1905. explanation=f"Attempted to call {self._error_prefix}.{name} with non-constant args.",
  1906. hints=[
  1907. "Ensure that the args to the method call are constant (int, str, etc.).",
  1908. ],
  1909. )
  1910. result = getattr(self.value, name)(*cargs, **ckwargs)
  1911. if variables.ConstantVariable.is_literal(result):
  1912. return variables.ConstantVariable.create(result)
  1913. if isinstance(result, re.Match):
  1914. return ConstantLikeVariable(result)
  1915. unimplemented(
  1916. gb_type="constant-like method call with unsupported return type",
  1917. context=f"{self._error_prefix}.{name}(*{args}, **{kwargs}) returned {result}",
  1918. explanation=f"Attempted to call {self._error_prefix}.{name}, got unsupported return value {result}.",
  1919. hints=[
  1920. *graph_break_hints.SUPPORTABLE,
  1921. ],
  1922. )
  1923. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  1924. result = getattr(self.value, name)
  1925. if isinstance(result, self.np_floating):
  1926. result = float(result)
  1927. if isinstance(result, self.np_dtype):
  1928. return NumpyDTypeVariable(result)
  1929. if isinstance(result, type) and issubclass(result, self.np_generic):
  1930. # things like x.dtype.type
  1931. return NumpyVariable(result)
  1932. if variables.ConstantVariable.is_literal(result):
  1933. return variables.ConstantVariable.create(result)
  1934. return GetAttrVariable(self, name)
  1935. class TorchVersionVariable(ConstantLikeVariable):
  1936. _error_prefix = "torch.__version__"
  1937. def __init__(self, **kwargs: Any) -> None:
  1938. kwargs.setdefault("value", torch.__version__)
  1939. assert kwargs["value"] is torch.__version__
  1940. super().__init__(**kwargs)
  1941. class NumpyDTypeVariable(ConstantLikeVariable):
  1942. def as_proxy(self) -> str:
  1943. """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable:
  1944. np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype.
  1945. This also handles unsupported things nicely (i.e. structured arrays and object arrays).
  1946. """
  1947. return self.value.type.__name__
# torch._numpy functions whose numpy counterparts produce constant "collection"
# objects. Each value is the tracker class that NumpyVariable.call_function uses
# to wrap the result of calling the real numpy function on constant arguments
# (looked up through NumpyVariable.get_constant_collection_for_func).
np_constant_collections_map = {
    tnp.finfo: ConstantLikeVariable,
    tnp.iinfo: ConstantLikeVariable,
    tnp.dtype: NumpyDTypeVariable,
}
  1953. class RandomClassVariable(VariableTracker):
  1954. """random.Random"""
  1955. def __init__(self, **kwargs: Any) -> None:
  1956. super().__init__(**kwargs)
  1957. def call_function(
  1958. self,
  1959. tx: "InstructionTranslator",
  1960. args: Sequence[VariableTracker],
  1961. kwargs: dict[str, VariableTracker],
  1962. ) -> "RandomVariable":
  1963. if len(args) > 1 or kwargs:
  1964. unimplemented(
  1965. gb_type="random.Random() with improper arguments",
  1966. context=f"args: {args}, kwargs: {kwargs}",
  1967. explanation="random.Random() with > 1 arg or with kwargs is not supported.",
  1968. hints=[
  1969. *graph_break_hints.USER_ERROR,
  1970. ],
  1971. )
  1972. seed = variables.CONSTANT_VARIABLE_NONE if len(args) == 0 else args[0]
  1973. return RandomVariable(
  1974. seed=seed, mutation_type=variables.base.ValueMutationNew()
  1975. )
  1976. class RandomVariable(VariableTracker):
  1977. """random.Random()
  1978. Implemented by wrapping a VariableTracker around a random.Random object.
  1979. The supported methods for the random.Random object cannot be overridden.
  1980. Assumes that random objects behave the same given a set seed or state.
  1981. """
  1982. _nonvar_fields = {
  1983. "random",
  1984. *VariableTracker._nonvar_fields,
  1985. }
  1986. _supported_fn_names = {
  1987. "random",
  1988. "randint",
  1989. "randrange",
  1990. "uniform",
  1991. }
  1992. def __init__(
  1993. self,
  1994. rand: random.Random | None = None,
  1995. seed: VariableTracker | None = None,
  1996. **kwargs: Any,
  1997. ) -> None:
  1998. super().__init__(**kwargs)
  1999. if rand is not None:
  2000. assert self.is_supported_random_obj(rand)
  2001. self.random = random.Random()
  2002. self.random.setstate(rand.getstate())
  2003. else:
  2004. seed = seed.as_python_constant() if seed is not None else None
  2005. self.random = random.Random(seed)
  2006. def python_type(self) -> type[random.Random]:
  2007. return random.Random
  2008. def as_python_constant(self) -> random.Random:
  2009. return self.random
  2010. @staticmethod
  2011. def is_supported_random_obj(val: Random) -> bool:
  2012. if type(val) is not random.Random:
  2013. return False
  2014. for name in itertools.chain(
  2015. RandomVariable._supported_fn_names, ("seed", "getstate", "setstate")
  2016. ):
  2017. if not hasattr(val, name):
  2018. return False
  2019. meth = getattr(val, name)
  2020. if inspect.isbuiltin(meth):
  2021. # e.g. random.Random.random
  2022. if meth != getattr(random.Random, name).__get__(val):
  2023. return False
  2024. else:
  2025. if getattr(meth, "__func__", None) is not getattr(random.Random, name):
  2026. return False
  2027. return True
  2028. @staticmethod
  2029. def check_state(state: tuple[int, tuple[int, ...], float | None]) -> None:
  2030. assert type(state) is tuple
  2031. assert type(state[0]) is int
  2032. assert type(state[1]) is tuple
  2033. assert all(type(x) is int for x in state[1])
  2034. assert state[2] is None or type(state[2]) is float
  2035. @staticmethod
  2036. def wrap_state(state: tuple[int, tuple[int, ...], float | None]) -> TupleVariable:
  2037. RandomVariable.check_state(state)
  2038. return variables.TupleVariable(
  2039. [
  2040. variables.ConstantVariable.create(state[0]),
  2041. variables.TupleVariable(
  2042. [variables.ConstantVariable.create(x) for x in state[1]]
  2043. ),
  2044. variables.ConstantVariable.create(state[2]),
  2045. ]
  2046. )
  2047. @staticmethod
  2048. def unwrap_state(
  2049. state: VariableTracker,
  2050. ) -> tuple[int, tuple[int, ...], float | None]:
  2051. state_obj = state.as_python_constant()
  2052. RandomVariable.check_state(state_obj)
  2053. return state_obj
  2054. def call_method(
  2055. self,
  2056. tx: "InstructionTranslator",
  2057. name: str,
  2058. args: list[VariableTracker],
  2059. kwargs: dict[str, VariableTracker],
  2060. ) -> VariableTracker:
  2061. if name == "seed":
  2062. tx.output.side_effects.mutation(self)
  2063. self.random.seed(
  2064. *[x.as_python_constant() for x in args],
  2065. **{key: val.as_python_constant() for key, val in kwargs.items()},
  2066. )
  2067. return variables.CONSTANT_VARIABLE_NONE
  2068. elif name == "getstate":
  2069. return self.wrap_state(self.random.getstate())
  2070. elif name == "setstate":
  2071. tx.output.side_effects.mutation(self)
  2072. self.random.setstate(self.unwrap_state(args[0]))
  2073. return variables.CONSTANT_VARIABLE_NONE
  2074. elif name in self._supported_fn_names:
  2075. tx.output.side_effects.mutation(self)
  2076. state = self.random.getstate()
  2077. def call_random_meth(*args: Any, **kwargs: Any) -> Any:
  2078. r = random.Random()
  2079. r.setstate(state)
  2080. return getattr(r, name)(*args, **kwargs)
  2081. # self.random state not actually updated by call_random_meth, so update here
  2082. # by calling the method
  2083. getattr(self.random, name)(
  2084. *[x.as_python_constant() for x in args],
  2085. **{k: v.as_python_constant() for k, v in kwargs.items()},
  2086. )
  2087. return call_random_fn(tx, call_random_meth, args, kwargs)
  2088. return super().call_method(tx, name, args, kwargs)
    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that rebuilds this generator as a real ``random.Random``.

        Net effect, assuming the usual call semantics of ``codegen.call_function``::

            rng = random.Random()
            rng.setstate(<state of self.random at codegen time>)

        with ``rng`` left on top of the stack.
        """
        # Load random.Random as the callee (add_push_null manages the NULL
        # slot), then call it with no arguments -> fresh Random instance.
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(random),
                    codegen.create_load_attr("Random"),
                ]
            )
        )
        codegen.call_function(0, False)
        # NOTE using add_push_null may result in NULL being duplicated
        # so defer the push_null to call_function
        # (i.e. the setstate call below passes True to call_function instead
        # of wrapping the attribute load in add_push_null.)
        # Duplicate the instance: one copy is the reconstructed value, the
        # other is consumed by the `.setstate` lookup and call.
        codegen.dup_top()
        codegen.load_attr("setstate")
        # The state is read from self.random at codegen time, so it reflects
        # every seed/setstate/draw that call_method applied during tracing.
        codegen(self.wrap_state(self.random.getstate()))
        codegen.call_function(1, True)
        # Drop setstate's return value, leaving the Random instance on top.
        codegen.pop_top()
  2106. class WeakRefVariable(VariableTracker):
  2107. @staticmethod
  2108. # pyrefly: ignore[bad-param-name-override]
  2109. def build(
  2110. tx: "InstructionTranslator",
  2111. weakref_value: weakref.ReferenceType[Any],
  2112. source: Source | None,
  2113. **options: Any,
  2114. ) -> "WeakRefVariable":
  2115. assert source is not None
  2116. callback = weakref_value.__callback__
  2117. callback_source = source and AttrSource(source, "__callback__")
  2118. callback_vt = VariableTracker.build(tx, callback, callback_source)
  2119. referent = weakref_value()
  2120. source = source and WeakRefCallSource(source)
  2121. referent_vt = VariableTracker.build(tx, referent, source)
  2122. options["source"] = source
  2123. return WeakRefVariable(referent_vt, callback_vt, **options)
  2124. def __init__(
  2125. self, referent_vt: VariableTracker, callback_vt: VariableTracker, **options: Any
  2126. ) -> None:
  2127. super().__init__(**options)
  2128. self.referent_vt = referent_vt
  2129. self.callback_vt = callback_vt
  2130. def call_function(
  2131. self,
  2132. tx: "InstructionTranslator",
  2133. args: Sequence[VariableTracker],
  2134. kwargs: dict[str, VariableTracker],
  2135. ) -> VariableTracker:
  2136. return self.referent_vt
    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode equivalent to ``weakref.ref(<referent>, <callback>)``."""
        # Callee: weakref.ref, with the NULL slot arranged by add_push_null.
        codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref"))
        # Positional arguments, in order: the referent, then the callback
        # (which may represent None when the original ref had no callback).
        # NOTE(review): if the referent was already collected, build() saw
        # weakref_value() == None, and weakref.ref(None, ...) raises TypeError
        # at runtime -- confirm that case cannot reach here.
        codegen(self.referent_vt)
        codegen(self.callback_vt)
        # Call with 2 positional args; the second argument False presumably
        # means "no extra NULL", since add_push_null already handled it.
        codegen.extend_output(create_call_function(2, False))
  2142. def is_python_hashable(self) -> bool:
  2143. return self.referent_vt.is_python_hashable()
  2144. def get_python_hash(self) -> int:
  2145. # weakref relies on the referent's hash
  2146. return self.referent_vt.get_python_hash()
  2147. def is_python_equal(self, other: object) -> bool:
  2148. if not isinstance(other, WeakRefVariable):
  2149. return False
  2150. return self.referent_vt.is_python_equal(other.referent_vt)