functions.py 122 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274
  1. """
  2. Function-related variable tracking classes for Dynamo's symbolic execution.
  3. This module contains classes that track different types of functions during graph
  4. compilation, including:
  5. - User-defined functions and methods
  6. - Built-in functions and methods
  7. - Wrapped functions (e.g. from decorators)
  8. - Special function types (e.g. functools.partial)
  9. - Triton kernels and related function types
  10. These classes are responsible for:
  11. - Tracking function calls and their arguments
  12. - Managing function closures and cell variables
  13. - Handling function attributes and special methods
  14. - Maintaining guards for function identity and closure contents
  15. - Supporting function inlining and specialization
  16. - Enabling proper symbolic execution of different function types
  17. The variable trackers here work together with the rest of Dynamo to enable
  18. accurate graph capture while handling Python's various function-related behaviors.
  19. """
  20. import builtins
  21. import functools
  22. import inspect
  23. import itertools
  24. import logging
  25. import os
  26. import sys
  27. import traceback
  28. import types
  29. from collections import namedtuple
  30. from collections.abc import Callable, Sequence
  31. from types import CellType, FunctionType
  32. from typing import Any, cast, Literal, Optional, TYPE_CHECKING, TypeVar
  33. from typing_extensions import Never
  34. from weakref import WeakKeyDictionary
  35. import torch
  36. from torch._dynamo.exc import get_stack_above_dynamo
  37. from torch._guards import Source
  38. from torch.utils._pytree import is_namedtuple_class
  39. from .. import config, graph_break_hints, polyfills, variables
  40. from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
  41. from ..exc import (
  42. format_frame_info,
  43. get_dynamo_observed_exception,
  44. handle_observed_exception,
  45. InfiniteGeneratorError,
  46. ObservedException,
  47. ObservedGeneratorExit,
  48. ObservedUserStopIteration,
  49. raise_observed_exception,
  50. StepUnsupported,
  51. unimplemented,
  52. Unsupported,
  53. )
  54. from ..guards import GuardBuilder, install_guard
  55. from ..source import (
  56. AttrSource,
  57. CellContentsSource,
  58. ClosureSource,
  59. ConstantSource,
  60. DefaultsSource,
  61. GetItemSource,
  62. ImportSource,
  63. SkipGuardSource,
  64. TypeSource,
  65. )
  66. from ..utils import (
  67. check_constant_args,
  68. check_unspec_or_constant_args,
  69. cmp_name_to_op_mapping,
  70. identity,
  71. is_function,
  72. is_wrapper_or_member_descriptor,
  73. istype,
  74. make_cell,
  75. )
  76. from .base import (
  77. AsPythonConstantNotImplementedError,
  78. AttributeMutationNew,
  79. raise_type_error_exc,
  80. ValueMutationNew,
  81. VariableTracker,
  82. )
  83. from .constant import CONSTANT_VARIABLE_NONE, ConstantVariable
  84. try:
  85. from torch.distributed.fsdp._fully_shard import _fsdp_param_group
  86. except ModuleNotFoundError:
  87. _fsdp_param_group = None # type: ignore[assignment]
  88. if TYPE_CHECKING:
  89. from torch._dynamo.codegen import PyCodegen
  90. from torch._dynamo.symbolic_convert import (
  91. InliningGeneratorInstructionTranslator,
  92. InliningInstructionTranslator,
  93. InstructionTranslator,
  94. InstructionTranslatorBase,
  95. )
  96. from torch._dynamo.variables.ctx_manager import ContextWrappingVariable
  97. from torch._higher_order_ops.triton_kernel_wrap import (
  98. TritonGridType,
  99. TritonKernelType,
  100. )
  101. from .lists import BaseListVariable, ListVariable
  102. from .tensor import TensorVariable
# Type variable bound to an arbitrary callable.
_F = TypeVar("_F", bound=Callable[..., Any])
# Code-object flag bits; FunctionSpec tests them against `code.co_flags` to
# find out whether a function declares `*args` / `**kwargs`.
CO_VARARGS = 0x04
CO_VARKEYWORDS = 0x08
# NOTE(review): presumably the keyword names accepted when handling
# tree_map-style calls; the consumers are not visible in this part of the
# file -- confirm against the call sites.
_SUPPORTED_TREE_MAP_KWARGS = frozenset({"namespace", "none_is_leaf", "is_leaf"})
_TREE_MAP_ONLY_SUPPORTED_KWARGS = frozenset({"is_leaf"})
# GitHub "new issue" link pre-filled with the `oncall: pt2` label and the
# pt2 bug-report template.
PT2_ISSUE_TRACKER_URL = "https://github.com/pytorch/pytorch/issues/new?&labels=oncall%3A+pt2&projects=&template=pt2-bug-report.yml"
# Module-level cache keyed by the function object. Entries are FunctionSpec
# instances (see _get_spec); weak keys mean the cache does not keep the
# functions alive.
_spec_cache: WeakKeyDictionary[Any, Any] = WeakKeyDictionary()
  111. # Raised when get_function() cannot convert a nested function to a Python function.
  112. class ClosureConversionError(NotImplementedError):
  113. pass
  114. @functools.lru_cache
  115. def get_pytree_SUPPORTED_NODES_source() -> AttrSource:
  116. return AttrSource(
  117. AttrSource(AttrSource(ImportSource("torch"), "utils"), "_pytree"),
  118. "SUPPORTED_NODES",
  119. )
  120. class FunctionSpec:
  121. def __init__(self, func: FunctionType) -> None:
  122. code = func.__code__
  123. vn = code.co_varnames
  124. self.posonly_count = code.co_posonlyargcount
  125. self.arg_count = code.co_argcount
  126. self.kwonly_count = code.co_kwonlyargcount
  127. self.posonly_names = vn[: self.posonly_count]
  128. self.pos_or_kw_names = vn[self.posonly_count : self.arg_count]
  129. self.all_pos_names = self.posonly_names + self.pos_or_kw_names
  130. self.kwonly_names = vn[self.arg_count : self.arg_count + self.kwonly_count]
  131. off = self.arg_count + self.kwonly_count
  132. self.varargs_name = vn[off] if code.co_flags & CO_VARARGS else None
  133. off += 1 if self.varargs_name else 0
  134. self.varkw_name = vn[off] if code.co_flags & CO_VARKEYWORDS else None
  135. def update_defaults(self, func: FunctionType) -> None:
  136. # Defaults can change from function call to function call. So re-update
  137. # them on every call.
  138. self.defaults = func.__defaults__ or ()
  139. self.kwdefaults = func.__kwdefaults__ or {}
  140. # Map positional-default names → their index in self.defaults
  141. self.pos_default_map = dict(
  142. zip(self.all_pos_names[-len(self.defaults) :], range(len(self.defaults)))
  143. )
  144. def _get_spec(func: FunctionType) -> FunctionSpec:
  145. spec = _spec_cache.get(func)
  146. if spec is None:
  147. spec = FunctionSpec(func)
  148. _spec_cache[func] = spec
  149. return spec
  150. def bind_args_cached(
  151. func: FunctionType,
  152. tx: "InstructionTranslator",
  153. fn_source: Source | None,
  154. args: Sequence[Any],
  155. kwargs: dict[str, Any],
  156. ) -> dict[str, VariableTracker]:
  157. spec = _get_spec(func)
  158. # Fast path: simple positional-only, no defaults, no varargs/varkw
  159. # This is the common case for small utility functions called repeatedly.
  160. if (
  161. len(args) == spec.arg_count
  162. and not func.__defaults__
  163. and not kwargs
  164. and not spec.varargs_name
  165. and not spec.varkw_name
  166. and not spec.kwonly_names
  167. ):
  168. return {
  169. name: wrap_bound_arg(tx, args[i])
  170. for i, name in enumerate(spec.all_pos_names)
  171. }
  172. # Full path with all features
  173. spec.update_defaults(func)
  174. ba = {}
  175. rem_kw = dict(kwargs)
  176. # 1) Bind all positional (pos-only + pos-or-kw)
  177. # 1.1) Apply pos-defaults first (maybe overridden later)
  178. for name, idx in spec.pos_default_map.items():
  179. default_source = None
  180. if fn_source and not (
  181. ConstantVariable.is_literal(spec.defaults[idx])
  182. and config.skip_guards_on_constant_func_defaults
  183. ):
  184. default_source = DefaultsSource(fn_source, idx)
  185. ba[name] = wrap_bound_arg(tx, spec.defaults[idx], default_source)
  186. # 1.2) Fill in provided positional args
  187. for i, name in enumerate(spec.all_pos_names):
  188. if i < len(args):
  189. # Maybe override pos-defaults applied above
  190. ba[name] = wrap_bound_arg(tx, args[i])
  191. elif name in rem_kw and (
  192. # `kwargs` can have the same key as a pos-only arg `name`.
  193. # If this case happens, we should not consume the `name` here and
  194. # keep it in `kwargs`:
  195. # >>> def fn(a, /, **kwargs): return (a, kwargs)
  196. # >>> fn(1, a=2)
  197. # (1, {'a': 2})
  198. name not in spec.posonly_names
  199. ):
  200. # Maybe override pos-defaults applied above
  201. ba[name] = wrap_bound_arg(tx, rem_kw.pop(name))
  202. elif name not in ba:
  203. raise TypeError(f"missing required positional argument: {name}")
  204. # 2) *args
  205. extra = args[len(spec.all_pos_names) :]
  206. if spec.varargs_name:
  207. ba[spec.varargs_name] = wrap_bound_arg(tx, tuple(extra))
  208. elif extra:
  209. raise TypeError(
  210. f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}"
  211. )
  212. # 3) Keyword-only
  213. for name in spec.kwonly_names:
  214. if name in rem_kw:
  215. ba[name] = wrap_bound_arg(tx, rem_kw.pop(name))
  216. elif name in spec.kwdefaults:
  217. kwdefault_source = None
  218. if fn_source:
  219. kwdefault_source = DefaultsSource(fn_source, name, is_kw=True)
  220. ba[name] = wrap_bound_arg(tx, spec.kwdefaults[name], kwdefault_source)
  221. else:
  222. raise TypeError(f"Missing required keyword-only argument: {name}")
  223. # 4) **kwargs
  224. if spec.varkw_name:
  225. ba[spec.varkw_name] = wrap_bound_arg(tx, rem_kw)
  226. elif rem_kw:
  227. raise TypeError(f"Unexpected keyword arguments: {list(rem_kw)}")
  228. return ba
  229. def wrap_bound_arg(
  230. tx: "InstructionTranslator", val: Any, source: Source | None = None
  231. ) -> VariableTracker:
  232. # Source propagation is best effort since not every object we encounter has a source to begin with.
  233. if isinstance(val, VariableTracker):
  234. return val
  235. elif not source:
  236. return VariableTracker.build(tx, val)
  237. else:
  238. # Create a lazy variable to avoid guarding on __defaults__ unless really
  239. # needed.
  240. return variables.LazyVariableTracker.create(val, source)
  241. def wrap_args_kwargs(tx: "InstructionTranslator", result: dict[str, Any]) -> None:
  242. for k, v in list(result.items()):
  243. if isinstance(v, (tuple, dict)):
  244. # args/kwargs
  245. result[k] = wrap_bound_arg(tx, v)
  246. def init_cellvars(
  247. parent: "InstructionTranslator",
  248. result: dict[str, VariableTracker],
  249. code: types.CodeType,
  250. ) -> None:
  251. """
  252. Update `result` to add mapping from local name to new cells created
  253. directly by `code`, or update SideEffects in `parent` if the a local cell is
  254. already in `result` (cell argument).
  255. """
  256. side_effects = parent.output.side_effects
  257. for name in code.co_cellvars:
  258. new_cell = side_effects.track_cell_new()
  259. if name in result:
  260. # This handles when a function argument is a cell (e.g., captured by
  261. # a nested func). See `MAKE_CELL` bytecode for more info.
  262. side_effects.store_cell(new_cell, result.pop(name))
  263. result[name] = new_cell
  264. def _create_nested_fn(
  265. code: types.CodeType,
  266. f_globals: dict[str, Any],
  267. name: str,
  268. defaults: tuple[object, ...] | None,
  269. closure: tuple[CellType] | None,
  270. kwdefaults: dict[str, Any] | None,
  271. annotations: dict[str, Any] | None,
  272. ) -> types.FunctionType:
  273. from types import FunctionType
  274. func = FunctionType(code, f_globals, name, defaults, closure)
  275. func.__kwdefaults__ = kwdefaults
  276. if isinstance(annotations, tuple):
  277. from itertools import pairwise
  278. annotations = dict(pairwise(annotations))
  279. # TypeError: __annotations__ must be set to a dict object
  280. assert annotations is None or isinstance(annotations, dict)
  281. func.__annotations__ = annotations # type: ignore[assignment]
  282. return func
# Function dunder attributes that fn_var_getattr reads with a regular
# `getattr` on the function instead of using the result of
# `inspect.getattr_static`.
fn_known_dunder_attrs = {
    "__annotations__",
    "__defaults__",
    "__kwdefaults__",
    "__code__",
    "__globals__",
    "__closure__",
    "__doc__",
}
  292. def fn_var_getattr(
  293. tx: "InstructionTranslator", fn: object, source: Source | None, name: str
  294. ) -> VariableTracker:
  295. source = source and AttrSource(source, name)
  296. if source and name == "__annotations__":
  297. # We get a large number of silly guards from annotations from inspect
  298. # module. Changing annotations is rare, and it impacting the extracted
  299. # graph is even rarer. So skip guards.
  300. source = SkipGuardSource(source)
  301. subobj = None
  302. try:
  303. subobj = inspect.getattr_static(fn, name)
  304. except AttributeError:
  305. # function does not have a __getattr__ or __getattribute__ method,
  306. # so we can safely assume that this attribute is absent
  307. raise_observed_exception(AttributeError, tx)
  308. # Special handling for known dunder attributes
  309. if name in fn_known_dunder_attrs:
  310. subobj = getattr(fn, name)
  311. if source:
  312. return variables.LazyVariableTracker.create(subobj, source)
  313. return VariableTracker.build(tx, subobj)
  314. class BaseUserFunctionVariable(VariableTracker):
  315. def get_filename(self) -> str:
  316. return self.get_code().co_filename
  317. def get_name(self) -> str:
  318. return self.get_code().co_name
  319. def get_globals(self) -> dict[str, Any]:
  320. raise NotImplementedError
  321. def get_code(self) -> types.CodeType:
  322. raise NotImplementedError
  323. def has_self(self) -> bool:
  324. raise NotImplementedError
  325. def call_function(
  326. self,
  327. tx: "InstructionTranslator",
  328. args: Sequence[VariableTracker],
  329. kwargs: dict[str, VariableTracker],
  330. ) -> VariableTracker:
  331. # Ignore patch_track_step_called from torch/optim/lr_scheduler.py - it just patches
  332. # the optimizer.step method and we don't need to trace it
  333. if (
  334. self.get_name() == "patch_track_step_called"
  335. and self.get_filename().endswith("torch/optim/lr_scheduler.py")
  336. ):
  337. return CONSTANT_VARIABLE_NONE
  338. return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) # type: ignore[attr-defined]
  339. def call_obj_hasattr(
  340. self, tx: "InstructionTranslator", name: str
  341. ) -> ConstantVariable:
  342. result = False
  343. try:
  344. result = hasattr(self.get_function(), name) # type: ignore[attr-defined]
  345. except NotImplementedError:
  346. if name == "__name__" and isinstance(self, NestedUserFunctionVariable):
  347. result = True
  348. return variables.ConstantVariable.create(result)
  349. def closure_vars(self, tx: "InstructionTranslator") -> dict[str, VariableTracker]:
  350. return {}
  351. # Override to set whether or not nested graph breaks should be allowed
  352. # if we create an inlining tx for this BaseUserFunctionVariable.
  353. # See symbolic_convert.py for where this function is called.
  354. def should_allow_nested_graph_breaks(self) -> bool:
  355. return True
  356. class UserFunctionVariable(BaseUserFunctionVariable):
  357. """Some unsupported user-defined global function"""
  358. _nonvar_fields = {
  359. "fn",
  360. "is_constant",
  361. *BaseUserFunctionVariable._nonvar_fields,
  362. }
  363. _TREE_MAP_MODULES = frozenset(
  364. {
  365. "optree",
  366. "optree.ops",
  367. "torch.utils._pytree",
  368. "torch.utils._cxx_pytree",
  369. }
  370. )
  371. @classmethod
  372. def create_with_source(cls, value: Any, source: Any) -> "UserFunctionVariable":
  373. install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
  374. return cls(value, source=source)
  375. def __init__(
  376. self,
  377. fn: types.FunctionType | torch.jit.ScriptFunction, # type: ignore[type-arg]
  378. is_constant: bool = False,
  379. **kwargs: Any,
  380. ) -> None:
  381. super().__init__(**kwargs)
  382. if getattr(fn, "_dynamo_marked_constant", False):
  383. # This method should be treated as a constant for the purposes of compilation
  384. self.is_constant = True
  385. else:
  386. self.is_constant = False
  387. # TODO putting this here to avoid duplication, because we could hit this
  388. # from several paths (e.g., SuperVariable or `var_getattr`s).
  389. if not isinstance(fn, (types.FunctionType, torch.jit.ScriptFunction)):
  390. unimplemented(
  391. gb_type="can't handle functions not implemented in python ",
  392. context=f"{fn}",
  393. explanation="Dynamo can only handle functions defined in python",
  394. hints=[
  395. "Move usage of this function out of `torch.compile` region",
  396. *graph_break_hints.INFERENCE_MODE,
  397. ],
  398. )
  399. # TODO(anijain2305) - Replace directly calling UserFunctionVariable with
  400. # VariableBuilder, which handles the wrapping of _torchdynamo_inline.
  401. # unpack @torch._dynamo.optimize()(fn) wrapped function
  402. fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
  403. self.fn = fn
  404. def as_python_constant(self) -> Any:
  405. if istype(self, UserFunctionVariable):
  406. return self.fn
  407. # subclasses (such as methods) usually aren't a constant
  408. return super().as_python_constant()
  409. def self_args(self) -> list[VariableTracker]:
  410. return []
  411. def get_function(self) -> types.FunctionType:
  412. return self.fn
  413. def get_code(self) -> types.CodeType:
  414. return self.fn.__code__
  415. def python_type(self) -> type:
  416. return types.FunctionType
  417. def has_self(self) -> bool:
  418. return getattr(self.fn, "__self__", None) is not None
  419. def get_globals(self) -> dict[str, Any]:
  420. return self.fn.__globals__
  421. def get_source(self) -> Source:
  422. source = self.source
  423. if source and isinstance(self, variables.UserMethodVariable):
  424. source = self.source_fn # type: ignore[assignment]
  425. return source # type: ignore[return-value]
  426. def bind_args(
  427. self,
  428. parent: "InstructionTranslator",
  429. args: Sequence[VariableTracker],
  430. kwargs: dict[str, VariableTracker],
  431. ) -> dict[str, VariableTracker]:
  432. """
  433. Assume `args` and `kwargs` are VariableTracker arguments for a call to
  434. this function, create new bindings for initial locals.
  435. """
  436. assert not self.is_constant
  437. fn: types.FunctionType = self.fn
  438. if not isinstance(fn, FunctionType):
  439. raise TypeError("Only supports regular Python functions.")
  440. root_tx = parent.output.root_tx
  441. source = self.get_source()
  442. result = bind_args_cached(fn, root_tx, source, args, kwargs) # type: ignore[arg-type]
  443. init_cellvars(parent, result, fn.__code__)
  444. closure = self.fn.__closure__ or ()
  445. assert len(closure) == len(self.fn.__code__.co_freevars)
  446. for idx, name, cell in zip(
  447. itertools.count(), self.fn.__code__.co_freevars, closure
  448. ):
  449. # TODO refactor these 3 branches.
  450. side_effects = parent.output.side_effects
  451. if cell in side_effects:
  452. cell_var = side_effects[cell]
  453. elif source:
  454. closure_cell = GetItemSource(ClosureSource(source), idx)
  455. closure_cell_contents = CellContentsSource(
  456. closure_cell, "cell_contents", freevar_name=name
  457. )
  458. try:
  459. contents_var = VariableTracker.build(
  460. parent, cell.cell_contents, closure_cell_contents
  461. )
  462. except ValueError:
  463. # Cell has not yet been assigned
  464. contents_var = variables.DeletedVariable()
  465. cell_var = side_effects.track_cell_existing(
  466. closure_cell, cell, contents_var
  467. )
  468. else:
  469. # TODO figure out why source isn't available here, and whether
  470. # we can fix that and remove this branch.
  471. try:
  472. contents_var = VariableTracker.build(parent, cell.cell_contents)
  473. except ValueError:
  474. # Cell has not yet been assigned
  475. contents_var = variables.DeletedVariable()
  476. cell_var = side_effects.track_cell_existing(None, cell, contents_var)
  477. result[name] = cell_var
  478. return result
  479. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  480. if name in cmp_name_to_op_mapping:
  481. return variables.GetAttrVariable(self, name)
  482. source = self.get_source()
  483. return fn_var_getattr(tx, self.fn, source, name)
  484. def call_obj_hasattr(
  485. self, tx: "InstructionTranslator", name: str
  486. ) -> ConstantVariable:
  487. result = hasattr(self.fn, name)
  488. return variables.ConstantVariable.create(result)
  489. def call_function(
  490. self,
  491. tx: "InstructionTranslator",
  492. args: Sequence[VariableTracker],
  493. kwargs: dict[str, VariableTracker],
  494. ) -> VariableTracker:
  495. # Handle patch_dynamo_config call
  496. if self.fn is torch._dynamo.patch_dynamo_config:
  497. try:
  498. args_const = [arg.as_python_constant() for arg in args]
  499. kwargs_const = {
  500. key: val.as_python_constant() for key, val in kwargs.items()
  501. }
  502. changes = torch._dynamo.patch_dynamo_config(
  503. *args_const, **kwargs_const
  504. ).changes
  505. return variables.DynamoConfigPatchVariable(changes)
  506. except AsPythonConstantNotImplementedError as e:
  507. raise RuntimeError(
  508. "Cannot convert patch_dynamo_config args/kwargs to constants. "
  509. "Please fix your call to patch_dynamo_config by using simpler inputs. "
  510. f"args: {args}, kwargs: {kwargs}"
  511. ) from e
  512. elif self.fn is torch._dynamo.error_on_graph_break:
  513. try:
  514. bound = inspect.signature(self.fn).bind(*args, **kwargs)
  515. error_on_graph_break = bound.arguments[
  516. "error_on_graph_break"
  517. ].as_python_constant()
  518. assert isinstance(error_on_graph_break, bool)
  519. return variables.ErrorOnGraphBreakVariable(error_on_graph_break)
  520. except Exception as e:
  521. raise RuntimeError(
  522. "Improper error_on_graph_break() call. Please fix your call to error_on_graph_break(). "
  523. f"args: {args}, kwargs: {kwargs}"
  524. ) from e
  525. # Handle a `nonstrict_trace(fn)` call
  526. elif self.fn is torch._dynamo.nonstrict_trace:
  527. bound = inspect.signature(self.fn).bind(*args, **kwargs)
  528. fn_var = bound.args[0]
  529. if not isinstance(fn_var, BaseUserFunctionVariable):
  530. typ = fn_var.python_type()
  531. msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>"
  532. unimplemented(
  533. gb_type="TypeError from user code",
  534. context=f"call_function({self.value}, {args}, {kwargs})", # type: ignore[attr-defined]
  535. explanation=msg,
  536. hints=[
  537. *graph_break_hints.USER_ERROR,
  538. ],
  539. )
  540. if not isinstance(fn_var, UserFunctionVariable):
  541. fn_name = fn_var.get_name()
  542. msg = f"Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." # noqa: B950
  543. unimplemented(
  544. gb_type="Limitation of `nonstrict_trace",
  545. context=f"{self}",
  546. explanation=msg,
  547. hints=[
  548. f"make sure definition of {fn_name} is outside ",
  549. "`torch.compile` region",
  550. ],
  551. )
  552. fn = fn_var.fn
  553. return variables.TorchInGraphFunctionVariable(
  554. fn, kind=variables.torch.AllowInGraphKind.NONSTRICT_TRACE
  555. )
  556. if self.is_constant:
  557. return invoke_and_store_as_constant(
  558. tx, self.fn, self.get_name(), args, kwargs
  559. )
  560. if (
  561. not tx.output.current_tracer.unsafe_allow_externally_visible_side_effects
  562. and self.fn
  563. is torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer
  564. ):
  565. with torch._dynamo.side_effects.allow_externally_visible_side_effects_in_subtracer(
  566. tx
  567. ):
  568. return super().call_function(tx, args, kwargs)
  569. if (
  570. getattr(tx.output.current_tracer, "description", None)
  571. == "torch.utils.checkpoint.checkpoint"
  572. and not tx.output.current_tracer.allow_side_effects_in_hop
  573. ):
  574. try:
  575. from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
  576. except Exception:
  577. FSDPState = None # type: ignore[assignment, misc]
  578. if FSDPState is not None and self.fn in [
  579. FSDPState._pre_forward,
  580. FSDPState._post_forward,
  581. ]:
  582. with torch._dynamo.side_effects.allow_side_effects_in_hop(tx):
  583. return super().call_function(tx, args, kwargs)
  584. tree_map_result = self._maybe_call_tree_map_fastpath(tx, args, kwargs)
  585. if tree_map_result is not None:
  586. return tree_map_result
  587. return super().call_function(tx, args, kwargs)
  588. def _maybe_call_tree_map_fastpath(
  589. self,
  590. tx: "InstructionTranslator",
  591. args: Sequence[VariableTracker],
  592. kwargs: dict[str, VariableTracker],
  593. ) -> VariableTracker | None:
  594. rewrite = self._rewrite_tree_map_only_call(tx, args, kwargs)
  595. if rewrite is not None:
  596. tree_map_fn, tree_map_args, tree_map_kwargs = rewrite
  597. else:
  598. tree_map_fn = self
  599. tree_map_args = args
  600. tree_map_kwargs = kwargs
  601. if not (
  602. isinstance(tree_map_fn, UserFunctionVariable)
  603. and tree_map_fn._is_tree_map_function()
  604. and not ({*tree_map_kwargs} - _SUPPORTED_TREE_MAP_KWARGS)
  605. and len(tree_map_args) >= 2
  606. ):
  607. return None
  608. map_fn = tree_map_args[0]
  609. first_tree = tree_map_args[1]
  610. rest = tree_map_args[2:]
  611. return first_tree.call_tree_map(
  612. tx,
  613. tree_map_fn,
  614. map_fn,
  615. rest,
  616. tree_map_kwargs,
  617. )
  618. def _is_tree_map_function(self) -> bool:
  619. return (
  620. getattr(self.fn, "__name__", None) == "tree_map"
  621. and getattr(self.fn, "__module__", None) in self._TREE_MAP_MODULES
  622. )
  623. def _is_tree_map_only_function(self) -> bool:
  624. return (
  625. getattr(self.fn, "__name__", None) == "tree_map_only"
  626. and getattr(self.fn, "__module__", None) in self._TREE_MAP_MODULES
  627. )
  628. def _rewrite_tree_map_only_call(
  629. self,
  630. tx: "InstructionTranslator",
  631. args: Sequence[VariableTracker],
  632. kwargs: dict[str, VariableTracker],
  633. ) -> (
  634. tuple[
  635. "UserFunctionVariable",
  636. Sequence[VariableTracker],
  637. dict[str, VariableTracker],
  638. ]
  639. | None
  640. ):
  641. if not self._is_tree_map_only_function():
  642. return None
  643. if len(args) != 3:
  644. return None
  645. if {*kwargs} - _TREE_MAP_ONLY_SUPPORTED_KWARGS:
  646. return None
  647. type_selector, map_fn, tree_arg = args
  648. allowed_types = self._extract_tree_map_only_types(type_selector)
  649. if allowed_types is None:
  650. return None
  651. tree_map_callable = self._lookup_tree_map_function()
  652. if tree_map_callable is None:
  653. return None
  654. wrapped_map_fn = TreeMapOnlyFunctionVariable(
  655. allowed_types,
  656. map_fn,
  657. source=getattr(map_fn, "source", None),
  658. )
  659. tree_map_variable = variables.UserFunctionVariable(tree_map_callable)
  660. return tree_map_variable, [wrapped_map_fn, tree_arg], dict(kwargs)
  661. def _lookup_tree_map_function(self) -> types.FunctionType | None:
  662. module_name = getattr(self.fn, "__module__", None)
  663. if not module_name:
  664. return None
  665. module = sys.modules.get(module_name)
  666. if module is None:
  667. return None
  668. tree_map = getattr(module, "tree_map", None)
  669. if isinstance(tree_map, types.FunctionType):
  670. return tree_map
  671. return None
  672. def _extract_tree_map_only_types(
  673. self, selector: VariableTracker
  674. ) -> tuple[type, ...] | None:
  675. if not selector.is_python_constant():
  676. return None
  677. try:
  678. raw_value = selector.as_python_constant()
  679. except NotImplementedError:
  680. return None
  681. flattened = self._flatten_type_spec(raw_value)
  682. if not flattened:
  683. return None
  684. if not all(isinstance(typ, type) for typ in flattened):
  685. return None
  686. return tuple(dict.fromkeys(flattened))
  687. def _flatten_type_spec(self, value: Any) -> list[type] | None:
  688. if isinstance(value, type):
  689. return [value]
  690. if isinstance(value, tuple):
  691. collected: list[type] = []
  692. for entry in value:
  693. flat = self._flatten_type_spec(entry)
  694. if flat is None:
  695. return None
  696. collected.extend(flat)
  697. return collected
  698. union_type = getattr(types, "UnionType", None)
  699. if union_type is not None and isinstance(value, union_type):
  700. collected = []
  701. for entry in value.__args__:
  702. flat = self._flatten_type_spec(entry)
  703. if flat is None:
  704. return None
  705. collected.extend(flat)
  706. return collected
  707. return None
  708. def is_python_hashable(self) -> Literal[True]:
  709. return True
  710. def get_python_hash(self) -> int:
  711. return hash(self.fn)
  712. def is_python_equal(self, other: object) -> bool:
  713. return isinstance(other, variables.UserFunctionVariable) and self.fn is other.fn
  714. class InspectSignatureVariable(UserFunctionVariable):
  715. """
  716. Variable tracker for inspect.signature with caching support.
  717. inspect.Signature is expensive to trace. When inspect.signature is called
  718. repeatedly on the same function during tracing, we cache the result to avoid
  719. retracing the signature construction each time. Although this is different
  720. from CPython behavior, it is safe to do so because inspect.signature does
  721. not change across different calls to the same function.
  722. """
  723. def call_function(
  724. self,
  725. tx: "InstructionTranslator",
  726. args: Sequence[VariableTracker],
  727. kwargs: dict[str, VariableTracker],
  728. ) -> VariableTracker:
  729. # Fast path: cache results for repeated calls on the same function
  730. if len(args) == 1 and not kwargs:
  731. target_arg = args[0]
  732. cache_key = None
  733. if isinstance(target_arg, (UserFunctionVariable, UserMethodVariable)):
  734. cache_key = target_arg.get_function()
  735. if cache_key is not None:
  736. if cache_key in tx.output.signature_cache:
  737. return tx.output.signature_cache[cache_key]
  738. result = super().call_function(tx, args, kwargs)
  739. tx.output.signature_cache[cache_key] = result
  740. return result
  741. return super().call_function(tx, args, kwargs)
  742. class TreeMapOnlyFunctionVariable(BaseUserFunctionVariable):
  743. _nonvar_fields = {
  744. "allowed_types",
  745. *BaseUserFunctionVariable._nonvar_fields,
  746. }
  747. def __init__(
  748. self,
  749. allowed_types: tuple[type, ...],
  750. map_fn: VariableTracker,
  751. **kwargs: Any,
  752. ) -> None:
  753. super().__init__(**kwargs)
  754. self.allowed_types = allowed_types
  755. self.map_fn = map_fn
  756. def python_type(self) -> type:
  757. return FunctionType
  758. def _matches_allowed_type(self, node: VariableTracker) -> bool:
  759. try:
  760. node_type = node.python_type()
  761. except NotImplementedError:
  762. return False
  763. return any(issubclass(node_type, allowed) for allowed in self.allowed_types)
  764. def call_function(
  765. self,
  766. tx: "InstructionTranslator",
  767. args: Sequence[VariableTracker],
  768. kwargs: dict[str, VariableTracker],
  769. ) -> VariableTracker:
  770. if not args:
  771. return self.map_fn.call_function(tx, args, kwargs)
  772. leaf = args[0]
  773. if self._matches_allowed_type(leaf):
  774. return self.map_fn.call_function(tx, args, kwargs)
  775. if len(args) != 1 or kwargs:
  776. # Defer to the original map function so we fall back to normal
  777. # tracing instead of triggering a graph break.
  778. return self.map_fn.call_function(tx, args, kwargs)
  779. return leaf
  780. class BuiltinMethodVariable(BaseUserFunctionVariable):
  781. def __init__(
  782. self, fn: types.BuiltinMethodType, is_constant: bool = False, **kwargs: Any
  783. ) -> None:
  784. super().__init__(**kwargs)
  785. assert isinstance(fn, types.BuiltinMethodType)
  786. self.fn = fn
  787. @staticmethod
  788. def is_supported_builtin_method(obj: Any) -> bool:
  789. method_self = obj.__self__
  790. method_name = obj.__name__
  791. # TODO(anijain2305) - Add support for more builtin methods
  792. # Supports tuple.__new__ and frozenset({....}).__contains__
  793. return (method_self is tuple and method_name == "__new__") or (
  794. type(method_self) is frozenset and method_name == "__contains__"
  795. )
  796. def call_function(
  797. self,
  798. tx: "InstructionTranslator",
  799. args: Sequence[VariableTracker],
  800. kwargs: dict[str, VariableTracker],
  801. ) -> VariableTracker:
  802. method_self = self.fn.__self__
  803. name = self.fn.__name__
  804. obj_source = self.source and AttrSource(self.source, "__self__")
  805. obj_vt = VariableTracker.build(tx, method_self, obj_source, realize=True)
  806. return obj_vt.call_method(tx, name, args, kwargs)
  807. class LocalGeneratorObjectVariable(VariableTracker):
  808. def __init__(
  809. self,
  810. code: types.CodeType,
  811. f_globals: dict[str, Any],
  812. inline_tracer: "InliningGeneratorInstructionTranslator",
  813. **kwargs: Any,
  814. ) -> None:
  815. super().__init__(**kwargs)
  816. self.code = code
  817. self.f_globals = f_globals
  818. self.inline_tracer = inline_tracer
  819. def get_code(self) -> types.CodeType:
  820. return self.code
  821. def get_filename(self) -> str:
  822. return self.get_code().co_filename
  823. def get_name(self) -> str:
  824. return self.get_code().co_name
  825. def get_function(self) -> Never:
  826. raise NotImplementedError("get_function")
  827. def has_self(self) -> bool:
  828. return False
  829. def __name__(self) -> str:
  830. return self.get_name()
  831. def __str__(self) -> str:
  832. return f"{self.__class__.__name__}({self.get_name()})"
  833. __repr__ = __str__
  834. def reconstruct(self, codegen: "PyCodegen") -> None:
  835. from torch._dynamo.side_effects import disallow_side_effects_in_generator
  836. from torch._dynamo.symbolic_convert import (
  837. save_and_restart_speculation_log,
  838. temporarely_allow_writes_to_output_graph,
  839. )
  840. tx = codegen.tx
  841. save = save_and_restart_speculation_log(tx)
  842. disallow = disallow_side_effects_in_generator(tx)
  843. temp = temporarely_allow_writes_to_output_graph(tx)
  844. with save, disallow, temp:
  845. tracer = self.inline_tracer
  846. if not tracer.generator_exhausted:
  847. self.remaining_items = self.force_unpack_var_sequence(tx)
  848. variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen)
  849. def get_globals(self) -> dict[str, Any]:
  850. return self.f_globals
  851. def python_type(self) -> type:
  852. return types.GeneratorType
  853. def next_variable(self, tx: "InstructionTranslatorBase") -> VariableTracker:
  854. tracer = self.inline_tracer
  855. if self._is_generator_exhausted():
  856. raise_observed_exception(StopIteration, tx)
  857. try:
  858. # Hierarchically, tx can be seen as the parent of the inline tracer
  859. # created on call_function. Any exception needs to be propagated to tx
  860. # for Dynamo to behave correctly
  861. return tracer.inline_call_()
  862. except ObservedException as e:
  863. tracer.generator_exhausted = True
  864. raise e
  865. except InfiniteGeneratorError:
  866. # test/dynamo/test_misc.py::test_iterator_limit
  867. unimplemented(
  868. gb_type="infinite generator detected",
  869. context="",
  870. explanation="Dynamo traced the YIELD_VALUE bytecode too many times. This could mean "
  871. "that we have attempted to trace an infinite generator.",
  872. hints=[
  873. f"If you are sure that your generator is not infinite, please report a bug at {PT2_ISSUE_TRACKER_URL}.",
  874. *graph_break_hints.USER_ERROR,
  875. ],
  876. )
  877. except Unsupported as e:
  878. torch._dynamo.eval_frame.skip_code(self.get_code())
  879. e.skip_frame = True
  880. if not tx.one_graph and not tx.error_on_graph_break:
  881. e.msg += "\n\nSkipping frame due to graph break in a generator's next() call."
  882. raise
  883. def call_obj_hasattr(
  884. self, tx: "InstructionTranslator", name: str
  885. ) -> ConstantVariable:
  886. if name in self.python_type().__dict__:
  887. return ConstantVariable.create(True)
  888. return ConstantVariable.create(False)
  889. def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
  890. return False
  891. def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
  892. return True
  893. def force_unpack_var_sequence(
  894. self, tx: "InstructionTranslatorBase"
  895. ) -> list[VariableTracker]:
  896. result: list[VariableTracker] = []
  897. self.force_apply_to_var_sequence(tx, result.append)
  898. return result
  899. def force_apply_to_var_sequence(
  900. self, tx: "InstructionTranslatorBase", fn: Callable[[VariableTracker], Any]
  901. ) -> None:
  902. while True:
  903. try:
  904. fn(self.next_variable(tx))
  905. except ObservedUserStopIteration:
  906. handle_observed_exception(tx)
  907. break
  908. # no nested graph breaks in generators
  909. def should_allow_nested_graph_breaks(self) -> Literal[False]:
  910. return False
  911. def _setup_exception(
  912. self, tx: "InstructionTranslator", exc: VariableTracker
  913. ) -> None:
  914. tracer = self.inline_tracer
  915. try:
  916. tracer._raise_exception_variable(exc)
  917. except ObservedException as e:
  918. # if no handler is available (i.e. user code doesn't catch it), the
  919. # exception is raised again.
  920. tracer.exception_handler(e)
  921. def _is_generator_just_started(self) -> bool:
  922. return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0
  923. def _is_generator_exhausted(self) -> bool:
  924. return getattr(self.inline_tracer, "generator_exhausted", False)
  925. def call_method(
  926. self,
  927. tx: "InstructionTranslator",
  928. name: str,
  929. args: list[VariableTracker],
  930. kwargs: dict[str, VariableTracker],
  931. ) -> VariableTracker:
  932. if name == "__next__":
  933. return self.next_variable(tx)
  934. elif name == "__iter__":
  935. # iter(gen) returns itself
  936. return self
  937. elif name == "send":
  938. # Sends a value into the generator function. Returns the next value
  939. # yielded by the generator, or raises StopIteration if the generator
  940. # exits without yielding another value
  941. if self._is_generator_just_started() and len(args):
  942. # can't send non-None value to a just-started generator
  943. # Test: GeneratorCPythonTests.test_send_non_none_to_new_gen
  944. if not all(arg.is_constant_none() for arg in args):
  945. raise_observed_exception(TypeError, tx)
  946. tracer = self.inline_tracer
  947. tracer.push_many(args)
  948. return self.next_variable(tx)
  949. elif name == "close":
  950. # * Raises a GeneratorExit at the point where the generator function was paused.
  951. # * If the generator function catches the exception and returns a
  952. # value, this value is returned from close() - Python 3.13+
  953. # * If the generator function is already closed, or raises GeneratorExit
  954. # (by not catching the exception), close() returns None.
  955. # * If the generator yields a value, a RuntimeError is raised.
  956. # * If the generator raises any other exception, it is propagated to the caller.
  957. # * If the generator has already exited due to an exception or normal
  958. # exit, close() returns None and has no other effect.
  959. # Return None if close is called on a just-started generator
  960. # See test GeneratorCloseCpythonTests::test_close_not_started
  961. tracer = self.inline_tracer
  962. if self._is_generator_just_started() or self._is_generator_exhausted():
  963. tracer.generator_exhausted = True
  964. return variables.CONSTANT_VARIABLE_NONE
  965. # Raise GeneratorExit to see if user code catches it. Any other exception
  966. # is propagated to the parent frame.
  967. try:
  968. self._setup_exception(
  969. tx, variables.ExceptionVariable(GeneratorExit, ())
  970. )
  971. # There's an extra block on Python 3.12+ to handle StopIteration
  972. # see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397
  973. #
  974. # 1 0 RETURN_GENERATOR
  975. # 2 POP_TOP
  976. # 4 RESUME 0
  977. # 2 6 LOAD_CONST 1 (1)
  978. # 8 YIELD_VALUE 1
  979. # 10 RESUME 1
  980. # 12 POP_TOP
  981. # 14 RETURN_CONST 0 (None)
  982. # >> 16 CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR)
  983. # 18 RERAISE 1
  984. # ExceptionTable:
  985. # 4 to 14 -> 16 [0] lasti
  986. if (
  987. sys.version_info >= (3, 12)
  988. and tracer.next_instruction.opname == "CALL_INTRINSIC_1"
  989. ):
  990. tracer.generator_exhausted = True
  991. return variables.CONSTANT_VARIABLE_NONE
  992. except ObservedGeneratorExit:
  993. # If it doesn't catch, we just return None, as per the text above
  994. tracer.generator_exhausted = True
  995. return variables.CONSTANT_VARIABLE_NONE
  996. try:
  997. # Raise RuntimeError if the generator yields any other value
  998. if self.next_variable(tx):
  999. raise_observed_exception(RuntimeError, tx)
  1000. except ObservedGeneratorExit:
  1001. tracer.generator_exhausted = True
  1002. return variables.CONSTANT_VARIABLE_NONE
  1003. except ObservedUserStopIteration:
  1004. # In Python 3.13+, one can capture GeneratorExit and return a value
  1005. # See test_generator.py::test_close_capture_GeneratorExit_return
  1006. # https://discuss.python.org/t/let-generator-close-return-stopiteration-value/24786/26
  1007. # https://github.com/python/cpython/pull/104771
  1008. assert tracer.symbolic_result is not None
  1009. return tracer.symbolic_result
  1010. elif name == "throw":
  1011. # * Raises an exception at the point where the generator was paused, and
  1012. # returns the next value yielded by the generator.
  1013. # * If the generator exits without yielding, raise StopIteration
  1014. # * If the generator function does not catch the passed-in exception,
  1015. # or raises a different exception, then that exception propagates to the caller.
  1016. # Setup the exception table and jump target in case of try...finally
  1017. tracer = self.inline_tracer
  1018. try:
  1019. # In Python 3.9, the exception is represented as a triple (typ, val, tb)
  1020. # In such cases, we re-raise the exception object given to avoid
  1021. # creating a new object, so that IS_OP works.
  1022. # See: https://github.com/pytorch/pytorch/pull/146496
  1023. self._setup_exception(tx, args[1] if len(args) == 3 else args[0])
  1024. except ObservedException: # noqa: TRY203
  1025. # propagate the exception back to the parent caller
  1026. raise
  1027. retval = self.next_variable(tx)
  1028. # The exception raised before is still active. We need to check the exception
  1029. # table one more time to find the next target. But why? Let's walk
  1030. # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M
  1031. #
  1032. # z = 0
  1033. # def whoo():
  1034. # global z
  1035. # z = 0
  1036. # try:
  1037. # yield 1
  1038. # except ValueError:
  1039. # yield 2
  1040. # finally:
  1041. # z += 1
  1042. # z += 10
  1043. #
  1044. # gen = whoo()
  1045. # next(gen)
  1046. # gen.throw(ValueError)
  1047. # print('z', z) -> z = 1
  1048. #
  1049. # ...
  1050. # >> 58 PUSH_EXC_INFO
  1051. #
  1052. # 8 60 LOAD_GLOBAL 2 (ValueError)
  1053. # 70 CHECK_EXC_MATCH
  1054. # 72 POP_JUMP_IF_FALSE 7 (to 88)
  1055. # 74 POP_TOP
  1056. #
  1057. # 9 76 LOAD_CONST 3 (2)
  1058. # 78 YIELD_VALUE 3 <------ ValueError is still active here
  1059. # 80 RESUME 1
  1060. # 82 POP_TOP
  1061. # 84 POP_EXCEPT
  1062. # 86 jump_backward 34 (to 20)
  1063. # ...
  1064. #
  1065. # ExceptionTable:
  1066. # 4 to 8 -> 124 [0] lasti
  1067. # 12 to 18 -> 58 [0]
  1068. # 20 to 56 -> 124 [0] lasti
  1069. # 58 to 82 -> 90 [1] lasti <------ move to 90
  1070. # 84 to 86 -> 96 [0]
  1071. # 88 to 88 -> 90 [1] lasti
  1072. # 90 to 94 -> 96 [0]
  1073. # 96 to 116 -> 118 [1] lasti
  1074. # 118 to 122 -> 124 [0] lasti
  1075. #
  1076. # In this scenario, a generator can yield after `throw()` is called. Even
  1077. # after the exception is raised a few lines above, it remains active
  1078. # within the `78 YIELD_VALUE` instruction. When the generator resumes
  1079. # after the second yield on instruction `80 RESUME`, we cannot simply
  1080. # return the control flow to the next instruction. Instead, one must
  1081. # check the exception table (or equivalent) to find the next target
  1082. # In this case, it says the instruction pointer must be moved to 90.
  1083. #
  1084. # Without this step, if we let the trace proceed to the next
  1085. # instruction, it would follow the control flow where the exception
  1086. # raised by `throw()` was handled and swallowed, potentially leading
  1087. # to incorrect behavior.
  1088. exc_type = type("__InternalThrowException", (Exception,), {})
  1089. try:
  1090. self._setup_exception(tx, variables.ExceptionVariable(exc_type, ()))
  1091. self.next_variable(tx)
  1092. except get_dynamo_observed_exception(exc_type):
  1093. # We should get back the exception raised before.
  1094. pass
  1095. else:
  1096. raise_observed_exception(RuntimeError, tracer)
  1097. return retval
  1098. return super().call_method(tx, name, args, kwargs)
class ContextlibContextManagerLocalGeneratorObjectVariable(
    LocalGeneratorObjectVariable
):
    """
    .. note::

        This is only used when the function is annotated with @contextlib.contextmanager

        It is a special case of a generator function, as we do not allow returning a
        context manager from a torch.compile function.
    """
  1108. class LocalGeneratorFunctionVariable(BaseUserFunctionVariable):
  1109. """functions that behaves like iterators
  1110. .. note::
  1111. This is a wrapper around (Nested)UserFunctionVariable
  1112. """
  1113. def __init__(
  1114. self,
  1115. vt: BaseUserFunctionVariable,
  1116. *,
  1117. generator_cls: type = LocalGeneratorObjectVariable,
  1118. **kwargs: Any,
  1119. ) -> None:
  1120. super().__init__(**kwargs)
  1121. self.vt = vt
  1122. self.generator_cls = generator_cls
  1123. def __getattr__(self, name: str) -> Any:
  1124. if name in self.__class__.__dict__:
  1125. return getattr(self, name)
  1126. return getattr(self.vt, name)
  1127. # These need to be explicit so the custom __getattr__ doesn't fall back to the unimplemented base class version
  1128. def get_code(self) -> types.CodeType:
  1129. return self.vt.get_code()
  1130. def get_globals(self) -> dict[str, Any]:
  1131. return self.vt.get_globals()
  1132. def has_self(self) -> bool:
  1133. return self.vt.has_self()
  1134. def _build_inline_tracer(
  1135. self,
  1136. tx: "InstructionTranslatorBase",
  1137. args: list[VariableTracker],
  1138. kwargs: dict[str, VariableTracker],
  1139. ) -> "InliningInstructionTranslator":
  1140. from torch._dynamo.symbolic_convert import InliningInstructionTranslator
  1141. return InliningInstructionTranslator.build_inline_tracer(
  1142. tx,
  1143. self,
  1144. args,
  1145. kwargs,
  1146. )
  1147. def call_function(
  1148. self,
  1149. tx: "InstructionTranslator",
  1150. args: Sequence[VariableTracker],
  1151. kwargs: dict[str, VariableTracker],
  1152. ) -> VariableTracker:
  1153. if not is_generator(self.vt.get_code()):
  1154. unimplemented(
  1155. gb_type="non-generator contextlib.contextmanager",
  1156. context=str(self.vt.get_code()),
  1157. explanation="Cannot compile function decorated with `@contextlib.contextmanager` that is not a generator"
  1158. ", i.e. does not use `yield`",
  1159. hints=[
  1160. "Use `yield` in the function body instead of `return`.",
  1161. "Remove the `@contextlib.contextmanager` decorator.",
  1162. ],
  1163. )
  1164. inline_tracer = self._build_inline_tracer(tx, list(args), kwargs)
  1165. code = self.vt.get_code()
  1166. f_globals = self.vt.get_globals()
  1167. # calling a generator returns a generator object
  1168. return self.generator_cls(
  1169. code,
  1170. f_globals,
  1171. inline_tracer, # type: ignore[arg-type]
  1172. source=self.source,
  1173. )
  1174. class FunctionDecoratedByContextlibContextManagerVariable(
  1175. LocalGeneratorFunctionVariable
  1176. ):
  1177. """
  1178. .. note::
  1179. This is only used when the function is annotated with @contextlib.contextmanager
  1180. """
  1181. def __init__(self, vt: BaseUserFunctionVariable, **kwargs: Any) -> None:
  1182. super().__init__(
  1183. vt,
  1184. generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable,
  1185. **kwargs,
  1186. )
  1187. def _build_inline_tracer(
  1188. self,
  1189. tx: "InstructionTranslatorBase",
  1190. args: list[VariableTracker],
  1191. kwargs: dict[str, VariableTracker],
  1192. ) -> "InliningGeneratorInstructionTranslator":
  1193. # NOTE: This only exists to not break support for context manager when
  1194. # config.enable_faithful_generator_behavior = False and
  1195. # config.enable_trace_contextlib = True. In case the former is false,
  1196. # Dynamo should still be able to trace through @contextmanager functions
  1197. tracer = super()._build_inline_tracer(tx, args, kwargs)
  1198. assert isinstance(
  1199. tracer,
  1200. torch._dynamo.symbolic_convert.InliningGeneratorInstructionTranslator,
  1201. )
  1202. tracer.is_generator_from_ctx_manager = True
  1203. return tracer
  1204. class UserMethodVariable(UserFunctionVariable):
  1205. """Some unsupported user-defined method"""
  1206. def __init__(
  1207. self,
  1208. fn: Callable[..., Any],
  1209. obj: VariableTracker,
  1210. source_fn: Source | None = None,
  1211. **kwargs: Any,
  1212. ) -> None:
  1213. super().__init__(fn=fn, **kwargs) # type: ignore[arg-type]
  1214. self.obj = obj
  1215. self.source_fn = source_fn
  1216. # Note on source and source_fn
  1217. # Be careful with `source` when delegating to UserFunctionVariable
  1218. # (base-class) methods. In this __init__, `source` is a *bound method*
  1219. # object, but the base class expects the underlying *function* object.
  1220. # One way is to simplly use `__func__` to unwrap it.
  1221. #
  1222. # For recursive dict-tag optimizations, it can be faster to fetch the
  1223. # function directly from `cls.__dict__`; that's why we pass on
  1224. # `source_fn`. Whenever it is possible to access the function from
  1225. # cls.__dict__, we pass that on to `source_fn`. Because bind_args
  1226. # operates on the unbound function, most guards should target
  1227. # `source_fn` rather than the original `source`.
  1228. if source_fn is None and kwargs.get("source") is not None:
  1229. self.source_fn = AttrSource(kwargs.get("source"), "__func__") # type: ignore[assignment, arg-type]
  1230. def __repr__(self) -> str:
  1231. return f"{self.__class__.__name__}({self.fn}, {self.obj})"
  1232. def self_args(self) -> list[VariableTracker]:
  1233. return [self.obj]
  1234. def python_type(self) -> type[types.MethodType]:
  1235. return types.MethodType
  1236. def call_function(
  1237. self,
  1238. tx: "InstructionTranslator",
  1239. args: Sequence[VariableTracker],
  1240. kwargs: dict[str, VariableTracker],
  1241. ) -> VariableTracker:
  1242. # NOTE this is to handle methods annotated by `nonstrict_trace`.
  1243. # a `nonstrict_trace`-ed function will be wrapped by
  1244. # `VariableTracker.build` and route to `TorchInGraphFunctionVariable`,
  1245. # but in the case of method, we manually wrap it with `UserMethodVariable`
  1246. # inside `UserDefinedObjectVariable.var_getattr`.
  1247. #
  1248. # We might be able to simplify this away by canonicalizing the
  1249. # function/method wrapping code paths.
  1250. from ..trace_rules import is_leaf_function, is_nonstrict_trace_callable
  1251. if is_nonstrict_trace_callable(self.fn):
  1252. call_args = [*self.self_args(), *args]
  1253. var = variables.TorchInGraphFunctionVariable(
  1254. self.fn, kind=variables.torch.AllowInGraphKind.NONSTRICT_TRACE
  1255. )
  1256. return var.call_function(tx, call_args, kwargs)
  1257. if is_leaf_function(self.fn):
  1258. call_args = [*self.self_args(), *args]
  1259. var = variables.TorchInGraphFunctionVariable(
  1260. self.fn, kind=variables.torch.AllowInGraphKind.LEAF_FUNCTION
  1261. )
  1262. return var.call_function(tx, call_args, kwargs)
  1263. # For nn.Module methods, redirecting to NNModuleVariable.call_method for optimized solution
  1264. # rather than simple inlining. E.g, putting `call_method` op in FX graph for `forward` method
  1265. # since we ensure `forward` of allowed modules can be traced by AOT safely.
  1266. # Note this is not only for allowed modules, as user customized modules can extend from
  1267. # allowed modules but using parent's `forward` method, which is also covered by this branch.
  1268. # If we are tracing the higher order op, we want Dynamo to step inside
  1269. # the module call so that Dynamo can see the underlying parameters and
  1270. # buffers and raise them as inputs to the graph. The is_root_tracer
  1271. # check bypasses the if condition for non-root tracers and directly
  1272. # calls the super().call_function at the end, which is basically
  1273. # equivalent of inlining the method.
  1274. if tx.output.is_root_tracer() and isinstance(
  1275. self.obj, variables.NNModuleVariable
  1276. ):
  1277. module_attr = getattr(self.fn, "__module__", "")
  1278. # inline torch.nn.utils.parametrize
  1279. if (
  1280. module_attr is not None
  1281. and module_attr.startswith("torch.nn.")
  1282. and module_attr != "torch.nn.utils.parametrize"
  1283. or self.is_constant
  1284. ):
  1285. return self.obj.call_method(
  1286. tx, self.fn.__name__, list(args), kwargs, constant=self.is_constant
  1287. )
  1288. elif (
  1289. _fsdp_param_group is not None
  1290. and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state # type: ignore[attr-defined]
  1291. ):
  1292. return variables.TorchCtxManagerClassVariable(self.fn).call_function(
  1293. tx, (self.obj, *args), kwargs
  1294. )
  1295. if self.is_constant:
  1296. fn = getattr(self.obj.value, self.fn.__name__) # type: ignore[attr-defined]
  1297. return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
  1298. return super().call_function(tx, args, kwargs)
  1299. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  1300. if name == "__self__":
  1301. return self.obj
  1302. if name == "__func__":
  1303. # We might have a better way to access the function object, this
  1304. # information is stored in self.source_fn, use that to construct the
  1305. # variable tracker.
  1306. return VariableTracker.build(tx, self.fn, self.source_fn) # type: ignore[arg-type]
  1307. return super().var_getattr(tx, name)
  1308. class WrappedUserMethodVariable(UserMethodVariable):
  1309. def __init__(
  1310. self,
  1311. wrapped: UserMethodVariable,
  1312. context: "ContextWrappingVariable",
  1313. **kwargs: Any,
  1314. ) -> None:
  1315. kwargs.pop("fn", None)
  1316. kwargs.pop("obj", None)
  1317. super().__init__(wrapped.fn, wrapped.obj, **kwargs)
  1318. self.wrapped = wrapped
  1319. self.context = context
  1320. def call_function(
  1321. self,
  1322. tx: "InstructionTranslator",
  1323. args: Sequence[VariableTracker],
  1324. kwargs: dict[str, VariableTracker],
  1325. ) -> VariableTracker:
  1326. self.context.enter(tx)
  1327. result = super().call_function(tx, args, kwargs)
  1328. self.context.exit(tx)
  1329. return result
  1330. def reconstruct(self, codegen: "PyCodegen") -> None:
  1331. codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type]
  1332. codegen(self.wrapped)
  1333. codegen.extend_output(create_call_function(1, False))
  1334. class WrappedUserFunctionVariable(UserFunctionVariable):
  1335. def __init__(
  1336. self,
  1337. wrapped: UserFunctionVariable,
  1338. context: "ContextWrappingVariable",
  1339. **kwargs: Any,
  1340. ) -> None:
  1341. kwargs.pop("fn", None)
  1342. super().__init__(wrapped.fn, **kwargs)
  1343. self.wrapped = wrapped
  1344. self.context = context
  1345. def call_function(
  1346. self,
  1347. tx: "InstructionTranslator",
  1348. args: Sequence[VariableTracker],
  1349. kwargs: dict[str, VariableTracker],
  1350. ) -> VariableTracker:
  1351. self.context.enter(tx)
  1352. result = super().call_function(tx, args, kwargs)
  1353. self.context.exit(tx)
  1354. return result
  1355. def reconstruct(self, codegen: "PyCodegen") -> None:
  1356. codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type]
  1357. codegen(self.wrapped)
  1358. codegen.extend_output(create_call_function(1, False))
  1359. def invoke_and_store_as_constant(
  1360. tx: "InstructionTranslator",
  1361. fn: Callable[..., Any],
  1362. name: str,
  1363. args: Sequence[VariableTracker],
  1364. kwargs: dict[str, VariableTracker],
  1365. ) -> VariableTracker:
  1366. def convert(x: VariableTracker) -> Any:
  1367. if x.is_tensor():
  1368. return cast("TensorVariable", x).get_real_value()
  1369. return x.as_python_constant()
  1370. args = [convert(x) for x in args]
  1371. kwargs = {k: convert(v) for k, v in kwargs.items()}
  1372. res = fn(*args, **kwargs)
  1373. return tx.output.register_attr_or_module(
  1374. res,
  1375. name,
  1376. source=ConstantSource(name),
  1377. )
  1378. class NestedUserFunctionVariable(BaseUserFunctionVariable):
  1379. _nonvar_fields = {
  1380. "f_globals",
  1381. *BaseUserFunctionVariable._nonvar_fields,
  1382. }
  1383. def __init__(
  1384. self,
  1385. fn_name: VariableTracker,
  1386. code: VariableTracker,
  1387. f_globals: dict[str, Any],
  1388. defaults: VariableTracker | None,
  1389. kwdefaults: VariableTracker | None,
  1390. annotations: VariableTracker | None,
  1391. closure: VariableTracker | None,
  1392. # This is present when this function is created by
  1393. # `functools.wrap(wrapped_fn)(this_fn)`.
  1394. wrapped_fn: VariableTracker | None = None,
  1395. **kwargs: Any,
  1396. ) -> None:
  1397. if kwargs.get("mutation_type") is None:
  1398. kwargs.update(mutation_type=AttributeMutationNew())
  1399. super().__init__(**kwargs)
  1400. assert isinstance(fn_name.as_python_constant(), str)
  1401. assert isinstance(code.as_python_constant(), types.CodeType)
  1402. assert isinstance(f_globals, dict)
  1403. self.fn_name = fn_name
  1404. self.code = code
  1405. self.f_globals = f_globals
  1406. self.defaults = defaults
  1407. self.kwdefaults = kwdefaults
  1408. self.annotations = annotations
  1409. self.closure = closure
  1410. self.wrapped_fn: VariableTracker | None = wrapped_fn
  1411. def self_args(self) -> list[VariableTracker]:
  1412. return []
  1413. def as_python_constant(self) -> types.FunctionType:
  1414. return self.get_function()
  1415. def get_code(self) -> types.CodeType:
  1416. return self.code.as_python_constant()
  1417. def python_type(self) -> type[types.FunctionType]:
  1418. return types.FunctionType
  1419. def get_function(self, _converting: set[int] | None = None) -> types.FunctionType:
  1420. # _converting is used a way to break cycles when
  1421. # two nested_functions refer to each other.
  1422. from .base import AsPythonConstantNotImplementedError
  1423. self_id = id(self)
  1424. if _converting is None:
  1425. _converting = set()
  1426. if self_id in _converting:
  1427. raise ClosureConversionError(
  1428. "cycle detected in mutually recursive closures"
  1429. )
  1430. _converting.add(self_id)
  1431. try:
  1432. return self._get_function_impl(_converting)
  1433. except AsPythonConstantNotImplementedError as e:
  1434. raise ClosureConversionError(
  1435. "failed to convert closure cell to Python constant"
  1436. ) from e
  1437. finally:
  1438. _converting.discard(self_id)
  1439. def is_python_constant(self) -> bool:
  1440. try:
  1441. self.as_python_constant()
  1442. return True
  1443. except (NotImplementedError, Unsupported):
  1444. return False
  1445. def _get_function_impl(self, _converting: set[int]) -> types.FunctionType:
  1446. closure_cells = None
  1447. if self.closure:
  1448. from torch._dynamo.symbolic_convert import InstructionTranslator
  1449. tx = InstructionTranslator.current_tx()
  1450. cells = []
  1451. for cell_var in self.closure.items: # type: ignore[attr-defined]
  1452. # Get the cell contents from side_effects or pre_existing_contents
  1453. # load_cell will replay the side-effects
  1454. cell_contents = tx.output.side_effects.load_cell(cell_var)
  1455. # Check for self-referential closure (function capturing itself for recursion)
  1456. # For example:
  1457. # def outer():
  1458. # def helper(n):
  1459. # if n <= 0:
  1460. # return 0
  1461. # return n + helper(n - 1) # helper calls itself
  1462. # return helper
  1463. if cell_contents is self:
  1464. raise ClosureConversionError("self-referential nested function")
  1465. # If the cell contents is a NestedUserFunctionVariable, call get_function
  1466. # directly to properly propagate the _converting set for cycle detection
  1467. if isinstance(cell_contents, NestedUserFunctionVariable):
  1468. value = cell_contents.get_function(_converting)
  1469. else:
  1470. value = cell_contents.as_python_constant()
  1471. cells.append(make_cell(value))
  1472. closure_cells = tuple(cells)
  1473. func = types.FunctionType(
  1474. self.code.as_python_constant(),
  1475. self.f_globals,
  1476. self.fn_name.as_python_constant(),
  1477. argdefs=None,
  1478. closure=closure_cells,
  1479. )
  1480. if self.defaults:
  1481. func.__defaults__ = self.defaults.as_python_constant()
  1482. if self.kwdefaults:
  1483. func.__kwdefaults__ = self.kwdefaults.as_python_constant()
  1484. if self.annotations:
  1485. annotations = self.annotations.as_python_constant()
  1486. if isinstance(annotations, tuple):
  1487. from itertools import pairwise
  1488. annotations = dict(pairwise(annotations))
  1489. # TypeError: __annotations__ must be set to a dict object
  1490. assert isinstance(annotations, dict)
  1491. func.__annotations__ = annotations
  1492. return func
  1493. def call_setattr(
  1494. self,
  1495. tx: "InstructionTranslator",
  1496. name_var: VariableTracker,
  1497. val: VariableTracker,
  1498. ) -> VariableTracker:
  1499. tx.output.side_effects.store_attr(self, name_var.value, val) # type: ignore[attr-defined]
  1500. return CONSTANT_VARIABLE_NONE
  1501. def call_method(
  1502. self,
  1503. tx: "InstructionTranslator",
  1504. name: str,
  1505. args: Sequence[VariableTracker],
  1506. kwargs: dict[str, VariableTracker],
  1507. ) -> VariableTracker:
  1508. if name == "__setattr__":
  1509. return self.call_setattr(tx, *args)
  1510. return super().call_method(tx, name, list(args), kwargs)
  1511. def has_closure(self) -> bool:
  1512. return self.closure is not None
  1513. def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
  1514. if name == "__name__":
  1515. return self.get_name()
  1516. if name == "__code__":
  1517. return self.get_code()
  1518. if name == "__defaults__":
  1519. d = getattr(self, "defaults", None)
  1520. return d.as_python_constant() if d else None
  1521. return super().const_getattr(tx, name)
  1522. def call_obj_hasattr(
  1523. self, tx: "InstructionTranslator", name: str
  1524. ) -> ConstantVariable:
  1525. if name == "__code__":
  1526. return variables.ConstantVariable.create(hasattr(self, "code"))
  1527. if name == "__defaults__":
  1528. return variables.ConstantVariable.create(hasattr(self, "defaults"))
  1529. return super().call_obj_hasattr(tx, name)
  1530. def has_self(self) -> bool:
  1531. return False
  1532. def get_globals(self) -> dict[str, Any]:
  1533. return self.f_globals
  1534. def bind_args(
  1535. self,
  1536. parent: "InstructionTranslator",
  1537. args: Sequence[VariableTracker],
  1538. kwargs: dict[str, VariableTracker],
  1539. ) -> dict[str, VariableTracker]:
  1540. code = self.get_code()
  1541. func = types.FunctionType(
  1542. code,
  1543. self.f_globals,
  1544. self.fn_name.as_python_constant(),
  1545. tuple(self.defaults.items) if self.defaults else None, # type: ignore[attr-defined]
  1546. tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
  1547. )
  1548. if self.kwdefaults:
  1549. func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant() # type: ignore[attr-defined]
  1550. bound = inspect.signature(func).bind(*args, **kwargs)
  1551. bound.apply_defaults()
  1552. result = dict(bound.arguments.items())
  1553. wrap_args_kwargs(parent.output.root_tx, result) # type: ignore[arg-type]
  1554. init_cellvars(parent, result, code)
  1555. for idx, name in enumerate(code.co_freevars):
  1556. assert name not in result
  1557. cell = self.closure.items[idx] # type: ignore[attr-defined, union-attr]
  1558. result[name] = cell
  1559. return result
  1560. def reconstruct(self, codegen: "PyCodegen") -> None:
  1561. codegen.add_push_null(
  1562. lambda: codegen.load_import_from(__name__, "_create_nested_fn")
  1563. )
  1564. codegen(self.code)
  1565. codegen.extend_output([codegen.create_load_const_unchecked(self.f_globals)])
  1566. codegen(ConstantVariable.create(self.code.value.co_name)) # type: ignore[attr-defined]
  1567. if self.defaults:
  1568. codegen(self.defaults)
  1569. else:
  1570. codegen.extend_output([codegen.create_load_const(None)])
  1571. if self.closure:
  1572. codegen(self.closure)
  1573. else:
  1574. codegen.extend_output([codegen.create_load_const(None)])
  1575. if self.kwdefaults:
  1576. codegen(self.kwdefaults)
  1577. else:
  1578. codegen.extend_output([codegen.create_load_const(None)])
  1579. if self.annotations:
  1580. try:
  1581. annotations = self.annotations.as_python_constant()
  1582. codegen.extend_output(
  1583. [codegen.create_load_const_unchecked(annotations)]
  1584. )
  1585. except NotImplementedError:
  1586. codegen(self.annotations)
  1587. else:
  1588. codegen.extend_output([codegen.create_load_const(None)])
  1589. codegen.extend_output(create_call_function(7, False))
  1590. if self.wrapped_fn:
  1591. codegen.add_push_null(
  1592. lambda: codegen.load_import_from("functools", "wraps")
  1593. )
  1594. codegen(self.wrapped_fn)
  1595. codegen.extend_output(create_call_function(1, False))
  1596. codegen.extend_output(create_rot_n(2))
  1597. codegen.extend_output(create_call_function(1, True))
  1598. # codegen attributes
  1599. tx = codegen.tx
  1600. if tx.output.side_effects.has_pending_mutation(self):
  1601. for name, value in tx.output.side_effects.store_attr_mutations[
  1602. self
  1603. ].items():
  1604. codegen.dup_top()
  1605. codegen(value)
  1606. codegen.extend_output(create_rot_n(2))
  1607. codegen.store_attr(name)
  1608. class WrappedNestedUserFunctionVariable(NestedUserFunctionVariable):
  1609. def __init__(
  1610. self,
  1611. wrapped: NestedUserFunctionVariable,
  1612. context: "ContextWrappingVariable",
  1613. **kwargs: Any,
  1614. ) -> None:
  1615. kwargs.pop("fn_name", None)
  1616. kwargs.pop("code", None)
  1617. kwargs.pop("f_globals", None)
  1618. kwargs.pop("defaults", None)
  1619. kwargs.pop("kwdefaults", None)
  1620. kwargs.pop("annotations", None)
  1621. kwargs.pop("closure", None)
  1622. kwargs.pop("wrapped_fn", None)
  1623. super().__init__(
  1624. wrapped.fn_name,
  1625. wrapped.code,
  1626. wrapped.f_globals,
  1627. wrapped.defaults,
  1628. wrapped.kwdefaults,
  1629. wrapped.annotations,
  1630. wrapped.closure,
  1631. wrapped.wrapped_fn,
  1632. )
  1633. self.wrapped = wrapped
  1634. self.context = context
  1635. def call_function(
  1636. self,
  1637. tx: "InstructionTranslator",
  1638. args: Sequence[VariableTracker],
  1639. kwargs: dict[str, VariableTracker],
  1640. ) -> VariableTracker:
  1641. self.context.enter(tx)
  1642. result = super().call_function(tx, args, kwargs)
  1643. self.context.exit(tx)
  1644. return result
  1645. def reconstruct(self, codegen: "PyCodegen") -> None:
  1646. codegen.add_push_null(lambda: codegen(self.context))
  1647. codegen(self.wrapped)
  1648. codegen.extend_output(create_call_function(1, False))
  1649. class SkipFunctionVariable(VariableTracker):
  1650. _nonvar_fields = {
  1651. "value",
  1652. "reason",
  1653. *VariableTracker._nonvar_fields,
  1654. }
  1655. def __init__(self, value: Any, reason: str | None = None, **kwargs: Any) -> None:
  1656. super().__init__(**kwargs)
  1657. self.value = value
  1658. self.reason = reason
  1659. def as_python_constant(self) -> Any:
  1660. return self.value
  1661. @classmethod
  1662. def create_with_source(cls, value: Any, source: Source) -> "SkipFunctionVariable":
  1663. # Use closure match guard (i.e. guard on __code__ object instead of
  1664. # function id) to avoid guarding on nested functions.
  1665. if inspect.getattr_static(value, "_torchdynamo_disable", False):
  1666. # For torch._dynamo.disable function, ensure that the original
  1667. # function is guarded. Otherwise, the else branch will guard on the
  1668. # _dynamo.disable.__code__
  1669. guard_on_source = source
  1670. guard_on_value = value
  1671. while getattr(guard_on_value, "_torchdynamo_orig_callable", False):
  1672. guard_on_value = guard_on_value._torchdynamo_orig_callable
  1673. guard_on_source = AttrSource(
  1674. guard_on_source, "_torchdynamo_orig_callable"
  1675. )
  1676. guard_on_source.make_guard(GuardBuilder.CLOSURE_MATCH)
  1677. elif inspect.isbuiltin(value):
  1678. install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH))
  1679. elif not is_wrapper_or_member_descriptor(value):
  1680. # These descriptors are not guaranteed to return the same object on
  1681. # attribute lookup. They are unlikely to be changed, so we can skip
  1682. # guarding them.
  1683. install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
  1684. return cls(value, source=source)
  1685. def call_function(
  1686. self,
  1687. tx: "InstructionTranslator",
  1688. args: Sequence[VariableTracker],
  1689. kwargs: dict[str, VariableTracker],
  1690. ) -> VariableTracker:
  1691. if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
  1692. msg = inspect.getattr_static(self.value, "_torchdynamo_disable_msg", None)
  1693. unimplemented(
  1694. gb_type="Skip calling `torch.compiler.disable()`d function",
  1695. context=str(self.value),
  1696. explanation=f"Skip calling function `{self.value}` since it was wrapped "
  1697. f"with `torch.compiler.disable` (reason: {msg})",
  1698. hints=[
  1699. "Remove the `torch.compiler.disable` call",
  1700. ],
  1701. )
  1702. elif self.value is torch._dynamo.graph_break:
  1703. graph_break_msg = kwargs.get("msg")
  1704. if graph_break_msg:
  1705. graph_break_msg = graph_break_msg.as_python_constant()
  1706. unimplemented(
  1707. gb_type="Call to `torch._dynamo.graph_break()`",
  1708. context=f"Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`",
  1709. explanation=f"User-inserted graph break. Message: {graph_break_msg}",
  1710. hints=[
  1711. "Remove the `torch._dynamo.graph_break()` call.",
  1712. ],
  1713. )
  1714. elif self.value is torch._dynamo.skip_frame:
  1715. skip_frame_msg = kwargs.get("msg")
  1716. if skip_frame_msg:
  1717. skip_frame_msg = skip_frame_msg.as_python_constant()
  1718. else:
  1719. skip_frame_msg = ""
  1720. unimplemented(
  1721. gb_type="Call to `torch._dynamo.skip_frame()`",
  1722. context=f"Called `torch._dynamo.skip_frame()` with args `{args}`, kwargs `{kwargs}`. "
  1723. f"Skipping frame {format_frame_info(tx.f_code)}.",
  1724. explanation=f"User-inserted skip frame. Message: {skip_frame_msg}",
  1725. hints=[
  1726. "Remove the `torch._dynamo.skip_frame()` call.",
  1727. ],
  1728. skip_frame=True,
  1729. )
  1730. elif self.value is torch._dynamo.step_unsupported:
  1731. try:
  1732. unimplemented(
  1733. gb_type="Call to `torch._dynamo.step_unsupported()`",
  1734. context="",
  1735. explanation="User-inserted step_unsupported.",
  1736. hints=[
  1737. "Remove the `torch._dynamo.step_unsupported()` call.",
  1738. ],
  1739. )
  1740. except Unsupported as e:
  1741. raise StepUnsupported(e.msg) from None
  1742. else:
  1743. if config.dont_skip_tracing:
  1744. from .builder import SourcelessBuilder
  1745. # re-build the function, attempting to not skip
  1746. rebuilt_fn = SourcelessBuilder.create(tx, self.value)
  1747. # if we still get SkipFunctionVariable, then we *really* should skip this function
  1748. if not isinstance(rebuilt_fn, SkipFunctionVariable):
  1749. return rebuilt_fn.call_function(tx, args, kwargs)
  1750. qualname = getattr(self.value, "__qualname__", "<unknown qualname>")
  1751. module_or = getattr(self.value, "__module__", None)
  1752. module_name = "<unknown module>" if module_or is None else str(module_or)
  1753. try:
  1754. path = inspect.getfile(self.value)
  1755. explanation = (
  1756. f"Dynamo developers have intentionally marked that the function `{qualname}` "
  1757. f"in file `{path}` should not be traced."
  1758. )
  1759. hints = [
  1760. f"Avoid calling the function `{qualname}`.",
  1761. ]
  1762. # TODO improve trace_rules reasoning to provide better hints.
  1763. # How do we tell that a function/file should NOT be removed from skip files?
  1764. # Do a very basic check for now.
  1765. if "_dynamo" not in path:
  1766. hints += [
  1767. f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{qualname}` "
  1768. "to force tracing into the function. "
  1769. "More graph breaks may occur as a result of attempting to trace into the function.",
  1770. "Please file an issue to PyTorch.",
  1771. ]
  1772. except TypeError:
  1773. known_python_builtin_modules = {"_abc", "_warnings"}
  1774. if module_or in known_python_builtin_modules:
  1775. explanation = (
  1776. f"Dynamo does not know how to trace the Python builtin "
  1777. f"`{module_name}.{qualname}`."
  1778. )
  1779. hints = [
  1780. "If you are attempting to call a logging function (e.g. `_warnings.warn`), "
  1781. "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.",
  1782. "Please file an issue on GitHub "
  1783. "so the PyTorch team can add support for it. ",
  1784. ]
  1785. elif module_or is not None and module_or.startswith("optree"):
  1786. explanation = f"Dynamo cannot trace optree C/C++ function {module_name}.{qualname}."
  1787. hints = [
  1788. " Consider using torch.utils._pytree - "
  1789. "https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py"
  1790. ]
  1791. # also warn on it because most users won't see the graph break message
  1792. torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
  1793. else:
  1794. explanation = (
  1795. f"Dynamo does not know how to trace the builtin `{module_name}.{qualname}.` "
  1796. f"This function is either a Python builtin (e.g. _warnings.warn) "
  1797. f"or a third-party C/C++ Python extension (perhaps created with pybind)."
  1798. )
  1799. hints = [
  1800. "If it is a Python builtin, please file an issue on GitHub "
  1801. "so the PyTorch team can add support for it and see the next case for a workaround.",
  1802. "If it is a third-party C/C++ Python extension, please "
  1803. "either wrap it into a PyTorch-understood custom operator "
  1804. "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
  1805. "for more details) or, if it is traceable, use "
  1806. "`torch.compiler.allow_in_graph`.",
  1807. ]
  1808. # also warn on it because most users won't see the graph break message
  1809. torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
  1810. if qualname == "allow_in_graph":
  1811. explanation = (
  1812. "Found an allow_in_graph decorator to a function which "
  1813. "is created inside the parent function that is getting "
  1814. "compiled. This is not supported for now."
  1815. )
  1816. # pyrefly: ignore [implicit-any]
  1817. hints = []
  1818. reason = self.reason if self.reason else "<missing reason>"
  1819. unimplemented(
  1820. gb_type="Attempted to call function marked as skipped",
  1821. context=f"module: {module_name}, qualname: {qualname}, skip reason: {reason}",
  1822. explanation=explanation,
  1823. hints=hints,
  1824. )
  1825. def call_obj_hasattr(
  1826. self, tx: "InstructionTranslator", name: str
  1827. ) -> ConstantVariable:
  1828. return variables.ConstantVariable.create(hasattr(self.value, name))
  1829. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  1830. if name in cmp_name_to_op_mapping:
  1831. return variables.GetAttrVariable(self, name)
  1832. return fn_var_getattr(tx, self.value, self.source, name)
  1833. def is_python_hashable(self) -> bool:
  1834. return True
  1835. def get_python_hash(self) -> int:
  1836. return hash(self.value)
  1837. def is_python_equal(self, other: object) -> bool:
  1838. return (
  1839. isinstance(other, VariableTracker)
  1840. and self.as_python_constant() == other.as_python_constant()
  1841. )
  1842. class WrappedSkipFunctionVariable(SkipFunctionVariable):
  1843. def __init__(
  1844. self,
  1845. wrapped: SkipFunctionVariable,
  1846. context: "ContextWrappingVariable",
  1847. **kwargs: Any,
  1848. ) -> None:
  1849. kwargs.pop("value", None)
  1850. kwargs.pop("reason", None)
  1851. super().__init__(wrapped.value, reason=wrapped.reason, **kwargs)
  1852. self.wrapped = wrapped
  1853. self.context = context
  1854. def call_function(
  1855. self,
  1856. tx: "InstructionTranslator",
  1857. args: Sequence[VariableTracker],
  1858. kwargs: dict[str, VariableTracker],
  1859. ) -> VariableTracker:
  1860. self.context.enter(tx)
  1861. result = super().call_function(tx, args, kwargs)
  1862. self.context.exit(tx)
  1863. return result
  1864. def reconstruct(self, codegen: "PyCodegen") -> None:
  1865. codegen.add_push_null(lambda: codegen(self.context))
  1866. codegen(self.wrapped)
  1867. codegen.extend_output(create_call_function(1, False))
  1868. class WrapperUserFunctionVariable(VariableTracker):
  1869. """
  1870. Used to represent a wrapper object that contains the actual callable as an
  1871. attribute. For example, torch.jit.script/trace have the original function at
  1872. their _torchdynamo_inline attribute. Similarly, functions with
  1873. __script_if_tracing_wrapper have the original attr at "__original_fn".
  1874. """
  1875. def __init__(self, wrapper_obj: Any, attr_to_trace: str, **kwargs: Any) -> None:
  1876. super().__init__(**kwargs)
  1877. self.wrapper_obj = wrapper_obj
  1878. self.attr_to_trace = attr_to_trace
  1879. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  1880. if name == self.attr_to_trace:
  1881. val = getattr(self.wrapper_obj, self.attr_to_trace)
  1882. source = self.source and AttrSource(self.source, name)
  1883. return VariableTracker.build(tx, val, source)
  1884. return super().var_getattr(tx, name)
  1885. def self_args(self) -> list[VariableTracker]:
  1886. return []
  1887. def call_function(
  1888. self,
  1889. tx: "InstructionTranslator",
  1890. args: Sequence[VariableTracker],
  1891. kwargs: dict[str, VariableTracker],
  1892. ) -> VariableTracker:
  1893. if hasattr(self.wrapper_obj, "cache_info"):
  1894. target_fn = getattr(self.wrapper_obj, self.attr_to_trace, None)
  1895. module_name = getattr(target_fn, "__module__", "") or ""
  1896. if module_name.split(".", maxsplit=1)[0] != "torch":
  1897. frame_summary = tx.frame_summary()
  1898. filename = os.path.basename(frame_summary.filename)
  1899. lineno = frame_summary.lineno
  1900. msg = (
  1901. "Dynamo detected a call to a `functools.lru_cache`-wrapped "
  1902. f"function at '{filename}:{lineno}'. Dynamo ignores the "
  1903. "cache wrapper and directly traces the wrapped function. "
  1904. "Silent incorrectness is only a *potential* risk, not "
  1905. "something we have observed. "
  1906. "Enable TORCH_LOGS=+dynamo for a DEBUG stack trace.\n\n"
  1907. "This call originates from:\n"
  1908. f"{''.join(traceback.format_list([frame_summary]))}"
  1909. )
  1910. torch._dynamo.utils.warn_once(msg)
  1911. dynamo_logger = torch._dynamo.utils.logging.getLogger("torch._dynamo")
  1912. if dynamo_logger.isEnabledFor(logging.DEBUG):
  1913. user_stack = torch._guards.TracingContext.extract_stack()
  1914. user_stack = get_stack_above_dynamo() + user_stack
  1915. frame_loc = (user_stack[-1].filename, user_stack[-1].lineno)
  1916. user_stack_formatted = "".join(traceback.format_list(user_stack))
  1917. user_stack_trace = f"call to a lru_cache wrapped function at: {frame_loc[0]}:{frame_loc[1]}\n"
  1918. user_stack_trace += str(user_stack_formatted)
  1919. dynamo_logger.debug(user_stack_trace)
  1920. all_args = self.self_args() + list(args)
  1921. return variables.UserFunctionVariable(
  1922. polyfills.getattr_and_trace # type: ignore[arg-type]
  1923. ).call_function(
  1924. tx,
  1925. [self, variables.ConstantVariable(self.attr_to_trace), *all_args],
  1926. kwargs,
  1927. )
  1928. class WrapperUserMethodVariable(WrapperUserFunctionVariable):
  1929. """
  1930. Similar to WrapperUserFunctionVariable, but for methods. The only delta is
  1931. saving the vt for `self` object of the method which is then used by
  1932. WrapperUserFunctionVariable in `call_function` method.
  1933. """
  1934. def __init__(
  1935. self,
  1936. wrapper_obj: Any,
  1937. attr_to_trace: str,
  1938. self_obj: VariableTracker,
  1939. **kwargs: Any,
  1940. ) -> None:
  1941. super().__init__(wrapper_obj, attr_to_trace, **kwargs)
  1942. self.obj = self_obj
  1943. def self_args(self) -> list[VariableTracker]:
  1944. return [self.obj]
  1945. def _traceable_collective_remaps() -> dict[Any, Any]:
  1946. # We can't rely on importing from distributed, since it's not always built
  1947. if torch.distributed.is_available():
  1948. from torch.distributed._functional_collectives import (
  1949. traceable_collective_remaps,
  1950. )
  1951. return traceable_collective_remaps
  1952. return {}
  1953. def _traceable_collectives_source(
  1954. tx: "InstructionTranslator", fn: Callable[..., Any]
  1955. ) -> AttrSource:
  1956. assert torch.distributed.is_available(), "Illegal invocation."
  1957. assert fn in _traceable_collective_remaps().values()
  1958. inner_name = fn.__name__
  1959. path_source = tx.import_source("torch.distributed._functional_collectives")
  1960. return AttrSource(path_source, inner_name)
  1961. class CollectiveFunctionRewriteVariable(UserFunctionVariable):
  1962. """
  1963. Some of the torch.distributed.* collective APIs are possible to rewrite to 'traceable' collectives.
  1964. This class provides both a way to check if a function is remappable, and perform the remapping.
  1965. In the case that a function is 'remappable' but only for some combinations of call-time arguments,
  1966. we check the args at `call_function` time and fall back to graph-breaking if needed. This is no worse
  1967. than status-quo as we currently graph-break on all distributed.* collectives.
  1968. """
  1969. def __init__(
  1970. self,
  1971. fn: Callable[..., Any],
  1972. *,
  1973. replacement_var: UserFunctionVariable,
  1974. **kwargs: Any,
  1975. ) -> None:
  1976. super().__init__(fn, **kwargs) # type: ignore[arg-type]
  1977. assert isinstance(replacement_var, UserFunctionVariable)
  1978. self.replacement_var = replacement_var
  1979. @staticmethod
  1980. def create(
  1981. tx: "InstructionTranslator",
  1982. old_fn: Callable[..., Any],
  1983. source: Source,
  1984. **options: Any,
  1985. ) -> "CollectiveFunctionRewriteVariable":
  1986. new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn)
  1987. return CollectiveFunctionRewriteVariable(
  1988. old_fn,
  1989. replacement_var=UserFunctionVariable(new_fn, source=new_source, **options),
  1990. source=source,
  1991. **options,
  1992. )
  1993. @staticmethod
  1994. def can_rewrite(variable: Any) -> bool:
  1995. return (
  1996. inspect.isfunction(variable) and variable in _traceable_collective_remaps()
  1997. )
  1998. @staticmethod
  1999. def rewrite(
  2000. tx: "InstructionTranslator", fn: Callable[..., Any]
  2001. ) -> tuple[Any, AttrSource]:
  2002. new_fn = _traceable_collective_remaps()[fn]
  2003. return new_fn, _traceable_collectives_source(tx, new_fn)
  2004. def call_function(
  2005. self,
  2006. tx: "InstructionTranslator",
  2007. args: Sequence[VariableTracker],
  2008. kwargs: dict[str, VariableTracker],
  2009. ) -> VariableTracker:
  2010. # call_function must check any unsupported arguments and graph-break.
  2011. # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
  2012. # since that's the contract for putting a mapping in `traceable_collective_remaps`
  2013. import torch.distributed as dist
  2014. from torch.distributed._functional_collectives import REDUCE_OP_TO_STR
  2015. # Merge args into kwargs so positional and keyword args
  2016. # can be processed the same way.
  2017. signature = inspect.signature(self.fn)
  2018. kwargs = dict(signature.bind(*args, **kwargs).arguments)
  2019. args = ()
  2020. if "async_op" in kwargs and kwargs["async_op"].as_python_constant():
  2021. unimplemented(
  2022. gb_type="async_op=True for distributed collectives",
  2023. context=f"{self.fn}, {args=}, {kwargs=}",
  2024. explanation=f"`torch.compile` doesn't support `async_op=True for {self.fn}",
  2025. hints=[
  2026. *graph_break_hints.SUPPORTABLE,
  2027. ],
  2028. )
  2029. if self.fn in (
  2030. dist.all_reduce,
  2031. dist.reduce_scatter_tensor,
  2032. # pyrefly: ignore [deprecated]
  2033. dist._reduce_scatter_base,
  2034. ):
  2035. reduce_op_var = kwargs.get("op")
  2036. reduce_op = (
  2037. reduce_op_var.value # type: ignore[attr-defined]
  2038. if reduce_op_var is not None
  2039. else signature.parameters["op"].default
  2040. )
  2041. if reduce_op not in REDUCE_OP_TO_STR:
  2042. raise ValueError(f"Unsupported all_reduce op: {reduce_op}")
  2043. kwargs["op"] = variables.ConstantVariable.create(
  2044. REDUCE_OP_TO_STR[reduce_op]
  2045. )
  2046. return self.replacement_var.call_function(tx, args, kwargs)
  2047. class FunctoolsWrapsVariable(UserFunctionVariable):
  2048. def call_function(
  2049. self,
  2050. tx: "InstructionTranslator",
  2051. args: Sequence[VariableTracker],
  2052. kwargs: dict[str, VariableTracker],
  2053. ) -> VariableTracker:
  2054. if not kwargs and len(args) == 1:
  2055. def wraps(fn: Any) -> VariableTracker:
  2056. if isinstance(fn, variables.NestedUserFunctionVariable):
  2057. return fn.clone(wrapped_fn=args[0])
  2058. unimplemented(
  2059. gb_type="functools.wraps",
  2060. context=f"{fn}",
  2061. explanation="`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region",
  2062. hints=[
  2063. *graph_break_hints.SUPPORTABLE,
  2064. ],
  2065. )
  2066. return variables.LambdaVariable(wraps)
  2067. return super().call_function(tx, args, kwargs)
  2068. class CollectionsNamedTupleFunction(UserFunctionVariable):
  2069. def as_python_constant(self) -> Any:
  2070. return self.fn
  2071. def call_function(
  2072. self,
  2073. tx: "InstructionTranslator",
  2074. args: Sequence[VariableTracker],
  2075. kwargs: dict[str, VariableTracker],
  2076. ) -> VariableTracker:
  2077. constant_args = check_constant_args(args, kwargs)
  2078. if constant_args:
  2079. try:
  2080. value = self.fn(
  2081. *[x.as_python_constant() for x in args],
  2082. **{k: v.as_python_constant() for k, v in kwargs.items()},
  2083. )
  2084. except TypeError as exc:
  2085. raise_observed_exception(
  2086. type(exc),
  2087. tx,
  2088. args=list(map(ConstantVariable.create, exc.args)),
  2089. )
  2090. return variables.UserDefinedClassVariable(
  2091. # pyrefly: ignore[unbound-name]
  2092. value,
  2093. mutation_type=ValueMutationNew(),
  2094. )
  2095. unimplemented(
  2096. gb_type="namedtuple construction",
  2097. context=f"{args=}, {kwargs=}",
  2098. explanation="`torch.compile` only support certain input types for namedtuple",
  2099. hints=[
  2100. *graph_break_hints.SUPPORTABLE,
  2101. ],
  2102. )
  2103. class FunctoolsPartialVariable(VariableTracker):
  2104. _nonvar_fields = {
  2105. "original_cache_hash",
  2106. *VariableTracker._nonvar_fields,
  2107. }
  2108. def __init__(
  2109. self,
  2110. func: VariableTracker,
  2111. args: Sequence[VariableTracker],
  2112. keywords: dict[str, VariableTracker],
  2113. original_cache_hash: Any = None,
  2114. **kwargs: Any,
  2115. ) -> None:
  2116. super().__init__(**kwargs)
  2117. self.func = func
  2118. assert isinstance(args, list)
  2119. self.args = args
  2120. assert isinstance(keywords, dict)
  2121. self.keywords = keywords
  2122. # fake_value is used for id calculation. Creating this value and id'ng
  2123. # on it is sufficient for the tracing purposes.
  2124. self.fake_value = functools.partial(identity)
  2125. # Store cache_hash from the original partial for SAC context_fn caching
  2126. self.original_cache_hash = original_cache_hash
  2127. def python_type(self) -> type:
  2128. return functools.partial
  2129. def reconstruct(self, codegen: "PyCodegen") -> None:
  2130. codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial"))
  2131. codegen(self.func)
  2132. if self.args:
  2133. codegen.foreach(self.args)
  2134. if not self.keywords:
  2135. codegen.extend_output(create_call_function(len(self.args) + 1, False))
  2136. return
  2137. codegen.foreach(self.keywords.values())
  2138. keys = tuple(self.keywords.keys())
  2139. codegen.extend_output(
  2140. codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, False)
  2141. )
  2142. def get_function(self) -> Any:
  2143. return self.as_python_constant()
  2144. def call_function(
  2145. self,
  2146. tx: "InstructionTranslator",
  2147. args: Sequence[VariableTracker],
  2148. kwargs: dict[str, VariableTracker],
  2149. ) -> VariableTracker:
  2150. merged_args = self.args + list(args)
  2151. merged_kwargs = {**self.keywords, **kwargs}
  2152. return self.func.call_function(tx, merged_args, merged_kwargs)
  2153. def call_obj_hasattr(
  2154. self, tx: "InstructionTranslator", name: str
  2155. ) -> ConstantVariable:
  2156. # functools.partial uses slots, so attributes are constant
  2157. return variables.ConstantVariable.create(
  2158. hasattr(functools.partial(identity), name)
  2159. )
  2160. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  2161. source = self.source and AttrSource(self.source, name)
  2162. # Handle __slots__
  2163. if name == "func":
  2164. return self.func
  2165. if name == "args":
  2166. return variables.ListVariable(self.args, source=source)
  2167. if name == "keywords":
  2168. items = {ConstantVariable.create(k): v for k, v in self.keywords.items()}
  2169. return variables.ConstDictVariable(items, source=source)
  2170. if name in cmp_name_to_op_mapping:
  2171. return variables.GetAttrVariable(self, name)
  2172. raise_observed_exception(AttributeError, tx)
  2173. def as_python_constant(self) -> Any:
  2174. return functools.partial(
  2175. self.func.as_python_constant(),
  2176. *[arg.as_python_constant() for arg in self.args],
  2177. **{k: v.as_python_constant() for k, v in self.keywords.items()},
  2178. )
  2179. def guard_as_python_constant(self) -> Any:
  2180. """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
  2181. result = functools.partial(
  2182. self.func.guard_as_python_constant(),
  2183. *[v.guard_as_python_constant() for v in self.args],
  2184. **{k: v.guard_as_python_constant() for k, v in self.keywords.items()},
  2185. )
  2186. # Preserve cache_hash for SAC context_fn caching
  2187. if self.original_cache_hash is not None:
  2188. result.cache_hash = self.original_cache_hash # type: ignore[missing-attribute]
  2189. return result
  2190. def is_python_hashable(self) -> bool:
  2191. return (
  2192. self.func.is_python_hashable()
  2193. and all(arg.is_python_hashable() for arg in self.args)
  2194. and all(value.is_python_hashable() for value in self.keywords.values())
  2195. )
  2196. def get_python_hash(self) -> int:
  2197. func_hash = self.func.get_python_hash()
  2198. args_hash = (arg.get_python_hash() for arg in self.args)
  2199. values_hash = (value.get_python_hash() for value in self.keywords.values())
  2200. return hash((func_hash, *args_hash, *values_hash))
  2201. def is_python_equal(self, other: object) -> bool:
  2202. return (
  2203. isinstance(other, FunctoolsPartialVariable)
  2204. and self.func.is_python_equal(other.func)
  2205. and all(
  2206. arg_a.is_python_equal(arg_b)
  2207. for (arg_a, arg_b) in zip(self.args, other.args)
  2208. )
  2209. and all(
  2210. value_a.is_python_equal(value_b)
  2211. for (value_a, value_b) in zip(
  2212. self.keywords.values(), other.keywords.values()
  2213. )
  2214. )
  2215. )
  2216. class PolyfilledFunctionVariable(VariableTracker):
  2217. _nonvar_fields = {
  2218. "fn",
  2219. "wrapped_fn",
  2220. "traceable_fn",
  2221. *VariableTracker._nonvar_fields,
  2222. }
  2223. @classmethod
  2224. @functools.cache
  2225. def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]:
  2226. return {}
  2227. @classmethod
  2228. def create_with_source(
  2229. cls, value: Any, source: Source
  2230. ) -> "PolyfilledFunctionVariable":
  2231. install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
  2232. return cls(value, source=source)
  2233. def __init__(self, fn: _F, **kwargs: Any) -> None:
  2234. super().__init__(**kwargs)
  2235. # pyrefly: ignore[invalid-type-var]
  2236. self.fn: _F = fn
  2237. handler = self._get_polyfill_handlers().get(fn, fn)
  2238. traceable_fn = None
  2239. assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}"
  2240. for candidate_attr in (
  2241. "__torch_dynamo_polyfill__", # registered polyfill
  2242. "__python_implementation__", # self handler from third-party libraries
  2243. ):
  2244. candidate = getattr(handler, candidate_attr, None)
  2245. if candidate:
  2246. assert callable(candidate)
  2247. traceable_fn = candidate
  2248. break
  2249. else:
  2250. raise RuntimeError(
  2251. f"Polyfill handler {handler} does not have a traceable function"
  2252. )
  2253. self.wrapped_fn = handler
  2254. # pyrefly: ignore[invalid-type-var]
  2255. self.traceable_fn: _F = traceable_fn
  2256. @property
  2257. def polyfill_fn(self) -> Callable[..., Any]:
  2258. return self.traceable_fn
  2259. def can_constant_fold_through(self) -> bool:
  2260. return getattr(
  2261. self.wrapped_fn, "__torch_dynamo_can_constant_fold_through__", False
  2262. )
  2263. def get_function(self) -> Any:
  2264. return self.as_python_constant()
  2265. def call_function(
  2266. self,
  2267. tx: "InstructionTranslator",
  2268. args: Sequence[VariableTracker],
  2269. kwargs: dict[str, VariableTracker],
  2270. ) -> VariableTracker:
  2271. if self.can_constant_fold_through() and check_unspec_or_constant_args(
  2272. args, kwargs
  2273. ):
  2274. result = (
  2275. self.fn( # use the original function which is faster than the polyfill
  2276. *[x.as_python_constant() for x in args],
  2277. **{k: v.as_python_constant() for k, v in kwargs.items()},
  2278. )
  2279. )
  2280. return VariableTracker.build(tx, result)
  2281. # Special case for sum on tuple/list of ints
  2282. if (
  2283. self.fn is builtins.sum
  2284. and len(args) == 1
  2285. and not kwargs
  2286. and isinstance(args[0], (variables.ListVariable, variables.TupleVariable))
  2287. and all(
  2288. (x.is_python_constant() and isinstance(x.as_python_constant(), int))
  2289. or (isinstance(x, variables.SymNodeVariable) and x.python_type() is int)
  2290. for x in args[0].items
  2291. )
  2292. ):
  2293. return variables.SymNodeVariable.create(
  2294. tx,
  2295. tx.output.create_proxy(
  2296. "call_function",
  2297. torch.sym_sum,
  2298. (tuple(a.as_proxy() for a in args[0].items),),
  2299. {},
  2300. ),
  2301. sym_num=torch.sym_sum(
  2302. [
  2303. (
  2304. x.as_python_constant()
  2305. if x.is_python_constant()
  2306. else x.sym_num # type: ignore[attr-defined]
  2307. )
  2308. for x in args[0].items
  2309. ]
  2310. ),
  2311. )
  2312. traceable_function_variable = VariableTracker.build(tx, self.traceable_fn)
  2313. return traceable_function_variable.call_function(tx, args, kwargs)
  2314. def call_method(
  2315. self,
  2316. tx: "InstructionTranslator",
  2317. name: str,
  2318. args: list[VariableTracker],
  2319. kwargs: dict[str, VariableTracker],
  2320. ) -> VariableTracker:
  2321. if name == "__call__":
  2322. return self.call_function(tx, args, kwargs)
  2323. method = getattr(self.fn, name, None)
  2324. if not (method or is_function(method)):
  2325. raise_type_error_exc(tx, f"Cannot find callable {name} in {self.fn}")
  2326. options = {}
  2327. if self.source:
  2328. options["source"] = AttrSource(self.source, name)
  2329. # pyrefly: ignore[bad-specialization]
  2330. polyfilled_method_variable = PolyfilledFunctionVariable(method, **options)
  2331. return polyfilled_method_variable.call_function(tx, args, kwargs)
  2332. def as_python_constant(self) -> Any:
  2333. return self.fn
  2334. class SysFunctionVariable(VariableTracker):
  2335. def __init__(self, value: Any, **kwargs: Any) -> None:
  2336. super().__init__(**kwargs)
  2337. self.value = value
  2338. def exc_info(self, tx: "InstructionTranslator") -> "variables.TupleVariable":
  2339. if len(tx.exn_vt_stack):
  2340. exn = tx.exn_vt_stack[-1]
  2341. typ = exn.exc_type # type: ignore[union-attr]
  2342. tb = exn.var_getattr(tx, "__traceback__")
  2343. items = [VariableTracker.build(tx, typ), exn, tb]
  2344. else:
  2345. items = [
  2346. variables.CONSTANT_VARIABLE_NONE,
  2347. variables.CONSTANT_VARIABLE_NONE,
  2348. variables.CONSTANT_VARIABLE_NONE,
  2349. ]
  2350. return variables.TupleVariable(items) # type: ignore[arg-type]
  2351. def exception(self, tx: "InstructionTranslator") -> VariableTracker:
  2352. return self.exc_info(tx).items[1]
  2353. def call_function(
  2354. self,
  2355. tx: "InstructionTranslator",
  2356. args: Sequence[VariableTracker],
  2357. kwargs: dict[str, VariableTracker],
  2358. ) -> VariableTracker:
  2359. if self.value is sys.exc_info:
  2360. return self.exc_info(tx)
  2361. assert self.value is sys.exception
  2362. return self.exception(tx)
  2363. from torch._higher_order_ops.triton_kernel_wrap import (
  2364. create_tma_experimental_metadata,
  2365. create_tma_stable_metadata,
  2366. TMADescriptorMetadata,
  2367. TritonHOPifier,
  2368. )
  2369. class DynamoTritonHOPifier(TritonHOPifier):
  2370. def raise_unsupported(self, msg: str) -> Never:
  2371. unimplemented(
  2372. gb_type="triton kernel unsupported feature",
  2373. context="",
  2374. explanation=f"Encountered triton kernel unsupported feature: {msg}",
  2375. hints=[],
  2376. )
  2377. def is_callable(self, maybe_callable: VariableTracker) -> bool:
  2378. return isinstance(
  2379. maybe_callable, (NestedUserFunctionVariable, UserFunctionVariable)
  2380. )
  2381. def get_value(self, val: VariableTracker) -> Any:
  2382. return val.value # type: ignore[attr-defined]
  2383. def check_grid(self, grid: "BaseListVariable") -> tuple[torch.fx.proxy.Proxy, ...]:
  2384. from .lists import BaseListVariable
  2385. if isinstance(grid, BaseListVariable):
  2386. return grid.as_proxy()
  2387. else:
  2388. unimplemented(
  2389. gb_type="unsupported grid type for triton hop check_grid",
  2390. context=f"grid type = {type(grid)}",
  2391. explanation="`torch.compile` only supports list-like grid for check_grid",
  2392. hints=[
  2393. *graph_break_hints.SUPPORTABLE,
  2394. ],
  2395. )
  2396. def call_grid(
  2397. self, grid: Any, meta: dict[str, Any], tx: "InstructionTranslator"
  2398. ) -> Any:
  2399. meta_var = {variables.ConstantVariable.create(k): v for k, v in meta.items()}
  2400. grid = grid.call_function(tx, [meta_var], {})
  2401. return grid
  2402. # We use this function to wrap call_prune_configs
  2403. def call_user_defined_fn(
  2404. self,
  2405. user_fn: Callable[..., Any],
  2406. args: Sequence[VariableTracker],
  2407. kwargs: dict[str, VariableTracker],
  2408. tx: Optional["InstructionTranslator"],
  2409. variable: Any,
  2410. ) -> VariableTracker:
  2411. from .builder import SourcelessBuilder
  2412. wrapped_user_function = SourcelessBuilder.create(tx, user_fn) # type: ignore[arg-type]
  2413. result = wrapped_user_function.call_function(tx, args, kwargs)
  2414. return result
  2415. def wrap_user_defined_obj(
  2416. self,
  2417. user_obj: Any,
  2418. tx: Optional["InstructionTranslator"],
  2419. variable: Any,
  2420. name: str,
  2421. ) -> VariableTracker:
  2422. from .builder import VariableBuilder
  2423. assert tx is not None
  2424. wrapped_user_obj = VariableBuilder(
  2425. tx, AttrSource(variable.kernel_source, f"{name}")
  2426. )._wrap(user_obj)
  2427. return wrapped_user_obj
  2428. def maybe_unpack_configs(
  2429. self, configs: Any, tx: Optional["InstructionTranslator"]
  2430. ) -> list[Any]:
  2431. # unpack the list of configs
  2432. configs = configs.unpack_var_sequence(tx)
  2433. # guard_as_python_constant inserts guards for Dynamo to check if the configs object changed.
  2434. configs = [config.guard_as_python_constant() for config in configs]
  2435. return configs
  2436. def maybe_unpack_heuristic_result(self, result: VariableTracker) -> Any:
  2437. if not result.is_python_constant():
  2438. self.raise_unsupported(
  2439. "@triton.heuristics must return constant values because configs can only contain constant values."
  2440. )
  2441. return result.guard_as_python_constant()
  2442. # We need to override call_getitem here so that we can add the source in the case
  2443. # where we call the triton kernel with a grid
  2444. def call_getitem( # type: ignore[override]
  2445. self,
  2446. variable: "TritonKernelVariable",
  2447. args: Sequence[Any],
  2448. ) -> "TritonKernelVariable":
  2449. # __getitem__ should only be called if we don't already have a grid
  2450. # Only grid needs to be passed
  2451. if variable.grid is not None or len(args) != 1:
  2452. self.raise_unsupported(
  2453. "Triton kernels should be called with only a single grid"
  2454. )
  2455. return type(variable)(
  2456. kernel=variable.kernel,
  2457. kernel_idx=variable.kernel_idx,
  2458. grid=args[0],
  2459. kernel_source=variable.source,
  2460. )
  2461. def call_HOP(
  2462. self,
  2463. variable: "TritonKernelVariable",
  2464. grids: Any,
  2465. combined_args: dict[str, Any],
  2466. tx: "InstructionTranslator",
  2467. ) -> "variables.ConstantVariable":
  2468. from .dicts import ConstDictVariable
  2469. # as we can only pass tensors as non-const args in fx graph,
  2470. # here we replace TMA descriptors
  2471. # (TMADescriptorExperimentalVariable and TMADescriptorStableVariable
  2472. # instances) with the underlying tensors, while moving the
  2473. # TMA descriptor-related metadata to a separate argument,
  2474. # so that we can reconstruct the TMA descriptors downstream
  2475. tma_descriptor_metadata: TMADescriptorMetadata = {}
  2476. for k in list(combined_args.keys()):
  2477. v = combined_args[k]
  2478. if isinstance(
  2479. v, (TMADescriptorExperimentalVariable, TMADescriptorStableVariable)
  2480. ):
  2481. tma_descriptor_metadata[k] = v.to_metadata()
  2482. combined_args[k] = v.get_tensor()
  2483. combined_args_vt = {
  2484. variables.ConstantVariable.create(k): v for k, v in combined_args.items()
  2485. }
  2486. from torch._higher_order_ops.triton_kernel_wrap import (
  2487. kernel_side_table,
  2488. triton_kernel_wrapper_mutation,
  2489. )
  2490. # Combine args and kwargs and pass as a dict so that if user defined triton
  2491. # kernel uses variables as 'grid' or 'kernel', it does not conflict with
  2492. # parameters of the wrapper function
  2493. constant_args = {
  2494. k: v.as_python_constant()
  2495. for k, v in combined_args.items()
  2496. if isinstance(v, VariableTracker) and v.is_python_constant()
  2497. }
  2498. non_constant_args = {
  2499. k: v
  2500. for k, v in combined_args_vt.items()
  2501. if not (isinstance(v, VariableTracker) and v.is_python_constant())
  2502. }
  2503. for v in non_constant_args.values():
  2504. v = v.realize()
  2505. if not (v.is_tensor() or v.is_symnode_like()):
  2506. self.raise_unsupported(
  2507. f"Unexpected argument type for a Triton kernel: {repr(v)}."
  2508. )
  2509. constant_args_idx = kernel_side_table.add_constant_args(constant_args)
  2510. meta = ConstDictVariable(non_constant_args, dict)
  2511. tx.output.create_proxy(
  2512. "call_function",
  2513. triton_kernel_wrapper_mutation,
  2514. (),
  2515. {
  2516. "kernel_idx": variable.kernel_idx,
  2517. "constant_args_idx": constant_args_idx,
  2518. "grid": grids,
  2519. "tma_descriptor_metadata": tma_descriptor_metadata,
  2520. "kwargs": meta.as_proxy(),
  2521. },
  2522. )
  2523. return variables.ConstantVariable(
  2524. None,
  2525. )
# Module-level DynamoTritonHOPifier instance; TritonKernelVariable delegates
# its initialization, calls, `__getitem__` and `run` handling to it.
dynamo_triton_hopifier_singleton = DynamoTritonHOPifier()
  2527. class TritonKernelVariable(VariableTracker):
  2528. grid: "TritonGridType"
  2529. kernel: "TritonKernelType"
  2530. kernel_idx: int | None
  2531. kernel_source: "AttrSource"
  2532. def __init__(
  2533. self, kernel: Any, kernel_idx: int | None, grid: Any, **kwargs: Any
  2534. ) -> None:
  2535. self.kernel_source = kwargs.pop("kernel_source", None)
  2536. super().__init__(**kwargs)
  2537. dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)
  2538. def call_function(
  2539. self,
  2540. tx: "InstructionTranslator",
  2541. args: Sequence[VariableTracker],
  2542. kwargs: dict[str, VariableTracker],
  2543. ) -> VariableTracker:
  2544. return dynamo_triton_hopifier_singleton.call_triton_kernel( # type: ignore[return-value]
  2545. self, args, kwargs, tx
  2546. )
  2547. def call_method(
  2548. self,
  2549. tx: "InstructionTranslator",
  2550. name: str,
  2551. args: list[VariableTracker],
  2552. kwargs: dict[str, VariableTracker],
  2553. ) -> VariableTracker:
  2554. if name == "__getitem__":
  2555. return dynamo_triton_hopifier_singleton.call_getitem(self, args)
  2556. elif name == "run":
  2557. return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx) # type: ignore[return-value]
  2558. # Bail out to parent's implementation
  2559. return super().call_method(tx, name, args, kwargs)
  2560. def specialize_symbolic(self, arg: Any) -> Any:
  2561. from .constant import ConstantVariable
  2562. from .tensor import SymNodeVariable
  2563. # See [Note: Specialize tl.constexpr args in user-defined triton kernels]
  2564. if isinstance(arg, SymNodeVariable):
  2565. return ConstantVariable.create(arg.evaluate_expr())
  2566. return arg
  2567. class TMADescriptorExperimentalVariable(VariableTracker):
  2568. def __init__(
  2569. self,
  2570. data_ptr: "variables.DataPtrVariable",
  2571. dims: list[VariableTracker],
  2572. block_dims: list[VariableTracker],
  2573. element_size: VariableTracker,
  2574. **kwargs: Any,
  2575. ) -> None:
  2576. assert isinstance(data_ptr, variables.DataPtrVariable)
  2577. super().__init__(**kwargs)
  2578. self.data_ptr = data_ptr
  2579. self.dims = dims
  2580. self.block_dims = block_dims
  2581. self.element_size = element_size
  2582. def to_metadata(self) -> Any:
  2583. return create_tma_experimental_metadata(
  2584. [dim.as_proxy() for dim in self.dims],
  2585. [dim.as_proxy() for dim in self.block_dims],
  2586. self.element_size.as_proxy(),
  2587. )
  2588. def reconstruct(self, codegen: "PyCodegen") -> None:
  2589. codegen.add_push_null(
  2590. lambda: codegen.load_import_from(
  2591. "triton.tools.experimental_descriptor",
  2592. f"create_{len(self.dims)}d_tma_descriptor",
  2593. )
  2594. )
  2595. self.data_ptr.reconstruct(codegen)
  2596. args = [*self.dims, *self.block_dims, self.element_size]
  2597. codegen.foreach(args)
  2598. codegen.call_function(len(args) + 1, False)
  2599. def get_tensor(self) -> VariableTracker:
  2600. return self.data_ptr.from_tensor
  2601. class TMADescriptorStableVariable(VariableTracker):
  2602. def __init__(
  2603. self,
  2604. tensor: "TensorVariable",
  2605. block_shape: "ListVariable",
  2606. **kwargs: Any,
  2607. ) -> None:
  2608. assert tensor.is_tensor()
  2609. super().__init__(**kwargs)
  2610. self.tensor = tensor
  2611. self.block_shape = block_shape
  2612. def to_metadata(self) -> Any:
  2613. return create_tma_stable_metadata(
  2614. self.block_shape.as_proxy(),
  2615. )
  2616. def reconstruct(self, codegen: "PyCodegen") -> None:
  2617. codegen.add_push_null(
  2618. lambda: codegen.load_import_from(
  2619. "triton.tools.tensor_descriptor",
  2620. "TensorDescriptor",
  2621. )
  2622. )
  2623. codegen.load_method("from_tensor")
  2624. self.tensor.reconstruct(codegen)
  2625. codegen(self.block_shape)
  2626. codegen.call_method(2)
  2627. def get_tensor(self) -> Any:
  2628. return self.tensor
  2629. class CreateTMADescriptorExperimentalVariable(VariableTracker):
  2630. def __init__(
  2631. self,
  2632. rank: int,
  2633. **kwargs: Any,
  2634. ) -> None:
  2635. assert rank in (1, 2)
  2636. super().__init__(**kwargs)
  2637. self.rank = rank
  2638. def call_function(
  2639. self,
  2640. tx: "InstructionTranslator",
  2641. args: Sequence[VariableTracker],
  2642. kwargs: dict[str, VariableTracker],
  2643. ) -> VariableTracker:
  2644. ptr = kwargs["ptr"] if "ptr" in kwargs else args[0]
  2645. if not isinstance(ptr, variables.DataPtrVariable):
  2646. unimplemented(
  2647. gb_type="invalid ptr argument for create_tma_descriptor",
  2648. context=f"args = {args}, kwargs = {kwargs}",
  2649. explanation=f"Expected `ptr` argument of `create_{self.rank}d_tma_descriptor`"
  2650. "to be from a `.data_ptr()` call, represented internally by `DataPtrVariable`",
  2651. hints=[
  2652. "`torch.compile` may fail to internally represent result of `.data_ptr()` "
  2653. "with `DataPtrVariable` due to a graph break between the `.data_ptr()` call and "
  2654. f"`create_{self.rank}d_tma_descriptor`. Please ensure there were no graph breaks "
  2655. "between these two calls.",
  2656. ],
  2657. )
  2658. if self.rank == 1:
  2659. if len(args) + len(kwargs) != 4:
  2660. raise_type_error_exc(
  2661. tx,
  2662. f"TMA metadata rank=1 requires exactly 4 arguments, got {len(args) + len(kwargs)}",
  2663. )
  2664. dims = [
  2665. kwargs["dim"] if "dim" in kwargs else args[1],
  2666. ]
  2667. block_dims = [
  2668. kwargs["block_dim"] if "block_dim" in kwargs else args[2],
  2669. ]
  2670. else:
  2671. if len(args) + len(kwargs) != 6:
  2672. raise_type_error_exc(
  2673. tx,
  2674. f"TMA metadata rank=2 requires exactly 6 arguments, got {len(args) + len(kwargs)}",
  2675. )
  2676. dims = [
  2677. kwargs["dim1"] if "dim1" in kwargs else args[1],
  2678. kwargs["dim0"] if "dim0" in kwargs else args[2],
  2679. ]
  2680. block_dims = [
  2681. kwargs["block_dim1"] if "block_dim1" in kwargs else args[3],
  2682. kwargs["block_dim0"] if "block_dim0" in kwargs else args[4],
  2683. ]
  2684. element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1]
  2685. # to make pyrefy happy
  2686. assert isinstance(ptr, variables.DataPtrVariable)
  2687. return TMADescriptorExperimentalVariable(
  2688. data_ptr=ptr,
  2689. dims=dims,
  2690. block_dims=block_dims,
  2691. element_size=element_size,
  2692. )
  2693. class CreateTMADescriptorStableVariable(VariableTracker):
  2694. def call_function(
  2695. self,
  2696. tx: "InstructionTranslator",
  2697. args: Sequence[VariableTracker],
  2698. kwargs: dict[str, VariableTracker],
  2699. ) -> VariableTracker:
  2700. tensor = kwargs["tensor"] if "tensor" in kwargs else args[0]
  2701. block_shape = kwargs["block_shape"] if "block_shape" in kwargs else args[1]
  2702. return TMADescriptorStableVariable(
  2703. tensor=tensor, # type: ignore[arg-type]
  2704. block_shape=block_shape, # type: ignore[arg-type]
  2705. )
  2706. class PyTreeGetNodeTypeFunctionVariable(UserFunctionVariable):
  2707. """
  2708. `torch.utils._pytree._get_node_type` function is very hot function. We want to special case it to reduce Dynamo tracing time.
  2709. def _get_node_type(tree: Any) -> Any:
  2710. node_type = type(tree)
  2711. # All namedtuple types are implicitly registered as pytree nodes.
  2712. # XXX: Other parts of the codebase expect namedtuple types always return
  2713. # `namedtuple` instead of the actual namedtuple type. Even if the type
  2714. # is explicitly registered.
  2715. if is_namedtuple_class(node_type):
  2716. return namedtuple
  2717. return node_type
  2718. """
  2719. def call_function(
  2720. self,
  2721. tx: "InstructionTranslator",
  2722. args: Sequence[VariableTracker],
  2723. kwargs: dict[str, VariableTracker],
  2724. ) -> VariableTracker:
  2725. if len(args) != 1:
  2726. raise_type_error_exc(
  2727. tx,
  2728. f"pytree_get_node_type requires exactly 1 argument, got {len(args)}",
  2729. )
  2730. type_source = None
  2731. if args[0].source:
  2732. install_guard(args[0].source.make_guard(GuardBuilder.TYPE_MATCH))
  2733. type_source = TypeSource(args[0].source)
  2734. python_type = args[0].python_type()
  2735. if is_namedtuple_class(python_type):
  2736. type_source = AttrSource(ImportSource("collections"), "namedtuple")
  2737. return VariableTracker.build(tx, namedtuple, type_source)
  2738. return VariableTracker.build(tx, python_type, source=type_source)
  2739. class PyTreeTreeIsLeafFunctionVariable(UserFunctionVariable):
  2740. """
  2741. `torch.utils._pytree.tree_is_leaf` function is a hot function. We want to special case it to reduce Dynamo tracing time.
  2742. def tree_is_leaf(
  2743. tree: PyTree,
  2744. is_leaf: Callable[[PyTree], bool] | None = None,
  2745. ) -> bool:
  2746. if is_leaf is not None and is_leaf(tree):
  2747. return True
  2748. return _get_node_type(tree) not in SUPPORTED_NODES
  2749. When is_leaf is None (the common case), we can optimize by not tracing into the function.
  2750. When is_leaf is not None, we fall back to regular tracing since it requires executing user code.
  2751. """
  2752. def call_function(
  2753. self,
  2754. tx: "InstructionTranslator",
  2755. args: Sequence[VariableTracker],
  2756. kwargs: dict[str, VariableTracker],
  2757. ) -> VariableTracker:
  2758. # tree_is_leaf(tree, is_leaf=None)
  2759. if len(args) < 1 or len(args) > 2:
  2760. raise_type_error_exc(
  2761. tx,
  2762. f"tree_is_leaf requires 1 or 2 arguments, got {len(args)}",
  2763. )
  2764. # Check if is_leaf parameter is provided
  2765. is_leaf = kwargs.get("is_leaf", CONSTANT_VARIABLE_NONE)
  2766. if len(args) == 2:
  2767. is_leaf = args[1]
  2768. if not is_leaf.is_constant_none():
  2769. return super().call_function(tx, args, kwargs)
  2770. # Optimize the case where is_leaf is None
  2771. # return _get_node_type(tree) not in SUPPORTED_NODES
  2772. tree = args[0]
  2773. node_type_var = PyTreeGetNodeTypeFunctionVariable(
  2774. torch.utils._pytree._get_node_type
  2775. ).call_function(tx, [tree], {})
  2776. # If the SUPPORTED_NODES was seen earlier and mutated, there would be a
  2777. # source and that will give us the mutated SUPPORTED_NODES.
  2778. supported_nodes_var = VariableTracker.build(
  2779. tx,
  2780. torch.utils._pytree.SUPPORTED_NODES,
  2781. source=get_pytree_SUPPORTED_NODES_source(),
  2782. )
  2783. out = supported_nodes_var.call_method(tx, "__contains__", [node_type_var], {})
  2784. return ConstantVariable.create(not out.value)
  2785. class SparseTensorCreationSkipVariable(SkipFunctionVariable):
  2786. """
  2787. Skip variable for sparse tensor factory functions with clear messaging regarding lack of support.
  2788. """
  2789. def __init__(self, value: Any, **kwargs: Any) -> None:
  2790. reason = "sparse tensor creation is not supported in torch.compile"
  2791. super().__init__(value, reason=reason, **kwargs)
  2792. def call_function(
  2793. self,
  2794. tx: "InstructionTranslator",
  2795. args: Sequence[VariableTracker],
  2796. kwargs: dict[str, VariableTracker],
  2797. ) -> VariableTracker:
  2798. from .. import graph_break_hints
  2799. fn_name = getattr(self.value, "__name__", str(self.value))
  2800. unimplemented(
  2801. gb_type="Sparse tensor creation not supported",
  2802. context=f"function: {fn_name}",
  2803. explanation=(
  2804. f"torch.compile does not support sparse tensor creation functions like {fn_name}. "
  2805. "Sparse tensors require specialized handling that is not yet implemented in the compiler."
  2806. ),
  2807. hints=[*graph_break_hints.SPARSE_TENSOR],
  2808. )
  2809. class TritonSetAllocatorSkipVariable(SkipFunctionVariable):
  2810. """
  2811. Skip variable for triton.set_allocator with a clear message to move it outside the compiled region.
  2812. """
  2813. def __init__(self, value: Any, **kwargs: Any) -> None:
  2814. reason = "triton.set_allocator is not supported inside torch.compile"
  2815. super().__init__(value, reason=reason, **kwargs)
  2816. def call_function(
  2817. self,
  2818. tx: "InstructionTranslator",
  2819. args: Sequence[VariableTracker],
  2820. kwargs: dict[str, VariableTracker],
  2821. ) -> VariableTracker:
  2822. unimplemented(
  2823. gb_type="triton.set_allocator not supported",
  2824. context="triton.set_allocator called inside compiled region",
  2825. explanation=(
  2826. "triton.set_allocator is not supported inside torch.compile. "
  2827. "It modifies global Triton allocator state and cannot be traced."
  2828. ),
  2829. hints=[
  2830. "Move triton.set_allocator() outside of the torch.compile region "
  2831. "(call it before the compiled function)."
  2832. ],
  2833. )