base.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930
"""
Core variable tracking functionality for Dynamo. This module defines the fundamental
classes and systems used to track and manage variables during Dynamo's operation.

The module provides:
1. VariableTracker - The base class for tracking variables during compilation
2. MutationType system - Classes for tracking and managing mutations to variables
3. Source type management - Utilities for tracking variable origins and scope
4. Variable state management - Tools for managing variable state and transformations

These components form the foundation of Dynamo's variable handling system,
enabling accurate tracking and transformation of Python code into optimized
computations.
"""
  13. import collections
  14. import functools
  15. import logging
  16. from collections.abc import Callable, ItemsView, KeysView, Sequence, ValuesView
  17. from contextvars import ContextVar
  18. from enum import Enum
  19. from typing import Any, NoReturn, Optional, TYPE_CHECKING
  20. from torch._guards import Guard
  21. from torch.fx.proxy import Node
  22. from .. import graph_break_hints, variables
  23. from ..current_scope_id import current_scope_id
  24. from ..exc import raise_observed_exception, unimplemented
  25. from ..guards import GuardBuilder, install_guard
  26. from ..source import AttrSource, Source
  27. from ..utils import cmp_name_to_op_mapping, istype
  28. if TYPE_CHECKING:
  29. from ..codegen import PyCodegen
  30. from ..symbolic_convert import InstructionTranslator
  31. from .constant import ConstantVariable
  32. from .functions import UserFunctionVariable
  33. log = logging.getLogger(__name__)
  34. # Tracks active method calls on VariableTracker instances to detect self-referential
  35. # calls (e.g., as_python_constant on a list that contains itself). Maps
  36. # (id(instance), method_name) tuples to track which calls are in progress.
  37. _vt_active_calls: ContextVar[set[tuple[int, str]] | None] = ContextVar(
  38. "_vt_active_calls", default=None
  39. )
  40. class SourceType(Enum):
  41. """
  42. This Enum divides VariableTracker into 2 cases, depending on the variable
  43. it represents:
  44. - already existed that Dynamo began tracking while introspection (Existing)
  45. - is a new variable that is created during Dynamo introspection (New)
  46. In general, we have these invariants:
  47. 1. for `VariableTracker` associated with `Existing`, its `source` field must not be None.
  48. 2. for `VariableTracker` associated with `New`, most of the time its
  49. `source` field is None, except for cases like side effect codegen for
  50. `AttributeMutationNew`, during which we generate a
  51. `LocalSource('tmp...')` for such variable, to facilitate codegen.
  52. """
  53. Existing = 0
  54. New = 1
  55. class MutationType:
  56. """
  57. Base class for Variable.mutation_type. It encodes information about
  58. 1. The type of mutation Dynamo allows on the variable.
  59. 2. Whether the value represented by this variable already existed before
  60. Dynamo tracing.
  61. """
  62. def __init__(self, typ: SourceType) -> None:
  63. # In HigherOrderOperator tracing, we need to distinguish
  64. # between MutationTypes inside the HigherOrderOperator and
  65. # ones outside it. For example, it is not safe to mutate
  66. # `a` in the following example because it was constructed
  67. # in a different scope.
  68. #
  69. # def f(x):
  70. # a = 1
  71. # def g(x):
  72. # nonlocal a
  73. # a = 2
  74. # return x
  75. # return wrap(g, x) + a
  76. #
  77. # We use self.scope to distinguish this.
  78. # scope == 0: The object was an existing variable
  79. # scope == 1: The object was created while Dynamo
  80. # was introspecting a function
  81. # (and no HigherOrderOps were involved)
  82. # scope >= 2: The object was created through
  83. # Dynamo introspection of a HigherOrderOp.
  84. # The exact number corresponds to the level
  85. # of nested HigherOrderOps.
  86. if typ is SourceType.Existing:
  87. self.scope = 0
  88. elif typ is SourceType.New:
  89. self.scope = current_scope_id()
  90. else:
  91. unimplemented(
  92. gb_type="Unsupported SourceType",
  93. context=f"MutationType.__init__ {self} {typ}",
  94. explanation=f"Dynamo does not support the type `{typ}`",
  95. hints=[
  96. "This branch is not supposed to be reachable.",
  97. *graph_break_hints.DYNAMO_BUG,
  98. ],
  99. )
  100. class ValueMutationNew(MutationType):
  101. """
  102. This case of VariableTracker.mutation_type marker indicates
  103. 1. Dynamo allows mutation on the value itself (rather than its attributes).
  104. 2. The value is created by the bytecode Dynamo is tracing through.
  105. For instance, Dynamo could model a newly created list with this marker,
  106. indicating that while we need to model mutations to this list, we don't have
  107. to emit bytecode for these mutations if the list doesn't escape into the
  108. Python world.
  109. """
  110. def __init__(self) -> None:
  111. super().__init__(SourceType.New)
  112. def __hash__(self) -> int:
  113. return id(self)
  114. def __eq__(self, other: object) -> bool:
  115. return self is other
  116. class ValueMutationExisting(MutationType):
  117. """
  118. This case of VariableTracker.mutation_type marker indicates
  119. 1. Dynamo allows mutation on the value itself (rather than its attributes).
  120. 2. The value exists before Dynamo tracing started.
  121. For instance, Dynamo could model a pre-existing list with this marker,
  122. indicating that if we encounter mutations to this list, we need to buffer
  123. and re-apply those mutations after the graph runs, since the list might be
  124. used afterwards in Python.
  125. """
  126. # A flag to indicate whether mutation happened on the associated
  127. # `VariableTracker`. This enables SideEffects to accurately and quickly
  128. # filter out which pre-existing values it needs to generate mutation for.
  129. is_modified: bool
  130. def __init__(self, is_modified: bool = False) -> None:
  131. super().__init__(SourceType.Existing)
  132. self.is_modified = is_modified
  133. class AttributeMutation(MutationType):
  134. """
  135. This case of VariableTracker.mutation_type marker indicates that Dynamo
  136. allows mutation on the value's attributes.
  137. """
  138. class AttributeMutationExisting(AttributeMutation):
  139. """
  140. This case of VariableTracker.mutation_type marker indicates
  141. 1. Dynamo allows mutation on the value's attributes.
  142. 2. The value exists before Dynamo tracing started.
  143. For instance, Dynamo could model a pre-existing object with this marker,
  144. indicating that if we encounter mutations to this object, we need to buffer
  145. then re-apply those mutations after the graph runs, since the object might
  146. be used afterwards in Python.
  147. """
  148. def __init__(self) -> None:
  149. super().__init__(SourceType.Existing)
  150. class AttributeMutationNew(AttributeMutation):
  151. """
  152. This case of VariableTracker.mutation_type marker indicates
  153. 1. Dynamo allows mutation on the value's attributes.
  154. 2. The value is created by the bytecode Dynamo is tracing through.
  155. For instance, Dynamo could model a newly created object with this marker,
  156. indicating that while we need to model mutations to this object, we don't
  157. have to emit bytecode for these mutations if the object doesn't escape into
  158. the Python world.
  159. """
  160. def __init__(self, cls_source: Optional[Source] = None) -> None:
  161. super().__init__(SourceType.New)
  162. self.cls_source = cls_source
  163. def _is_top_level_scope(scope_id: int) -> bool:
  164. return scope_id == 1
  165. def is_side_effect_safe(m: MutationType) -> bool:
  166. scope_id = current_scope_id()
  167. # In the top-level scope (if no HigherOrderOperators are involved),
  168. # we are allowed to modify variables created in this scope as well
  169. # as existing variables.
  170. if _is_top_level_scope(scope_id):
  171. return True
  172. # Otherwise, only allow local mutation of variables created in the current scope
  173. return m.scope == scope_id
  174. # This helps users of `as_python_constant` to catch unimplemented error with
  175. # more information; it inherits `NotImplementedError` for backward
  176. # compatibility reasons.
  177. class AsPythonConstantNotImplementedError(NotImplementedError):
  178. vt: "VariableTracker"
  179. def __init__(self, vt: "VariableTracker", msg: str | None = None) -> None:
  180. msg = f"{vt} is not a constant" if msg is None else msg
  181. super().__init__(msg)
  182. self.vt = vt
  183. class VariableTrackerMeta(type):
  184. all_subclasses: list[type] = []
  185. def __new__(
  186. mcs: type, name: str, bases: tuple[type, ...], attrs: dict[str, Any]
  187. ) -> type:
  188. # Determine which metaclass to use based on the class attributes
  189. # Classes with _no_implicit_realize = True should NOT implicitly realize
  190. # (they need standard isinstance behavior to avoid infinite recursion)
  191. # Check if any base class has _no_implicit_realize set, or if it's in attrs
  192. no_implicit_realize = attrs.get("_no_implicit_realize", False) or any(
  193. getattr(base, "_no_implicit_realize", False) for base in bases
  194. )
  195. if no_implicit_realize or name == "VariableTracker":
  196. # Use base VariableTrackerMeta (no custom __instancecheck__)
  197. return super().__new__(VariableTrackerMeta, name, bases, attrs)
  198. else:
  199. # Use ImplicitRealizingVariableTrackerMeta for all other subclasses
  200. return super().__new__(
  201. ImplicitRealizingVariableTrackerMeta, name, bases, attrs
  202. )
  203. def __init__(
  204. cls: type, name: str, bases: tuple[type, ...], attrs: dict[str, Any]
  205. ) -> None:
  206. super().__init__(name, bases, attrs) # type: ignore[misc]
  207. VariableTrackerMeta.all_subclasses.append(cls)
  208. class ImplicitRealizingVariableTrackerMeta(VariableTrackerMeta):
  209. def __instancecheck__(self, instance: object) -> bool:
  210. """Make isinstance work with LazyVariableTracker"""
  211. if instancecheck(LazyVariableTracker, instance):
  212. return instance.lazy_isinstance(self) # pyrefly: ignore[missing-attribute]
  213. return instancecheck(self, instance)
  214. class VariableTracker(metaclass=VariableTrackerMeta):
  215. """
  216. Base class for tracked locals and stack values
  217. VariableTracker instances are immutable and should be copied in
  218. order to change them.
  219. Prefer the factory function VariableTracker.build() over VariableTracker.__init__().
  220. """
  221. # fields to leave unmodified in apply()
  222. _nonvar_fields = {
  223. "value",
  224. "guards",
  225. "source",
  226. "mutation_type",
  227. "parents_tracker",
  228. "user_code_variable_name",
  229. }
  230. def clone(self, **kwargs: Any) -> "VariableTracker":
  231. """Shallow copy with some (optional) changes"""
  232. args = dict(self.__dict__)
  233. args.update(kwargs)
  234. return self.__class__(**args)
  235. @classmethod
  236. def visit(
  237. cls,
  238. fn: Callable[["VariableTracker"], None],
  239. value: Any,
  240. cache: Optional[dict[int, Any]] = None,
  241. ) -> None:
  242. """
  243. Walk value and call fn on all the VariableTracker instances
  244. """
  245. if cache is None:
  246. cache = {}
  247. idx = id(value)
  248. if idx in cache:
  249. return
  250. # save `value` to keep it alive and ensure id() isn't reused
  251. cache[idx] = value
  252. if isinstance(value, VariableTracker):
  253. value = value.unwrap()
  254. fn(value)
  255. value = value.unwrap() # calling fn() might have realized it
  256. nonvars = value._nonvar_fields
  257. for key, subvalue in value.__dict__.items():
  258. if key not in nonvars:
  259. cls.visit(fn, subvalue, cache)
  260. elif istype(value, (list, tuple)):
  261. for subvalue in value:
  262. cls.visit(fn, subvalue, cache)
  263. elif istype(value, (dict, collections.OrderedDict)):
  264. for subvalue in value.values():
  265. cls.visit(fn, subvalue, cache)
  266. def __repr__(self) -> str:
  267. return f"{self.__class__.__name__}()"
  268. def debug_repr(self) -> str:
  269. # Intended to be overridden to provide more info
  270. try:
  271. return repr(self.as_python_constant())
  272. except NotImplementedError:
  273. return repr(self)
  274. def python_type(self) -> type:
  275. """
  276. Abstract method to be implemented by subclasses of VariableTracker.
  277. This method should return the type represented by the instance of the subclass.
  278. The purpose is to provide a standardized way to retrieve the Python type information
  279. of the variable being tracked.
  280. Returns:
  281. type: The Python type (such as int, str, list, etc.) of the variable tracked by
  282. the subclass. If the type cannot be determined or is not relevant,
  283. leaving it undefined or invoking super() is always sound.
  284. Note:
  285. This is an abstract method and may be overridden in subclasses.
  286. Example:
  287. class SetVariable(VariableTracker):
  288. def python_type(self):
  289. return set
  290. Raises:
  291. NotImplementedError: If the method is not implemented in a subclass.
  292. """
  293. try:
  294. return type(self.as_python_constant())
  295. except NotImplementedError:
  296. raise NotImplementedError(f"{self} has no type") from None
  297. def python_type_name(self) -> str:
  298. try:
  299. return self.python_type().__name__
  300. except NotImplementedError:
  301. return "<unknown type>"
  302. def as_python_constant(self) -> Any:
  303. """For constants"""
  304. raise AsPythonConstantNotImplementedError(self)
  305. def guard_as_python_constant(self) -> Any:
  306. """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
  307. try:
  308. return self.as_python_constant()
  309. except NotImplementedError:
  310. unimplemented(
  311. gb_type="Not a Python constant",
  312. context=f"guard_as_python_constant {self}",
  313. explanation=f"Failed to convert {self} into a Python constant.",
  314. hints=[],
  315. )
  316. def is_python_constant(self) -> bool:
  317. try:
  318. self.as_python_constant()
  319. return True
  320. except NotImplementedError:
  321. return False
  322. def is_constant_match(self, *values: Any) -> bool:
  323. """
  324. Check if this variable is a python constant matching one of the given values.
  325. Examples:
  326. var.is_constant_match(None) # True if var is constant None
  327. var.is_constant_match(True, False) # True if var is constant True or False
  328. var.is_constant_match(NotImplemented) # True if var is constant NotImplemented
  329. """
  330. return False
  331. def is_constant_none(self) -> bool:
  332. """Check if this variable is a constant None value."""
  333. return False
  334. def make_guard(self, fn: Callable[..., Any]) -> Guard:
  335. if self.source:
  336. return self.source.make_guard(fn)
  337. raise NotImplementedError
  338. # TODO[@lucaskabela] - change this type to `InstructionTranslatorBase`
  339. # and cascade that (large blast radius)
  340. def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
  341. """getattr(self, name) returning a python constant"""
  342. raise NotImplementedError
  343. def is_symnode_like(self) -> bool:
  344. """Return True for values that can participate in SymNode operations"""
  345. return False
  346. def is_tensor(self) -> bool:
  347. """Return True for TensorVariable instances"""
  348. return False
  349. def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
  350. """getattr(self, name) returning a new variable"""
  351. value = self.const_getattr(tx, name)
  352. if not variables.ConstantVariable.is_literal(value):
  353. raise NotImplementedError
  354. source = self.source and AttrSource(self.source, name)
  355. if source and not self.is_python_constant():
  356. # The second condition is to avoid guards on const getattr objects
  357. # like __code__.co_argcount
  358. install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
  359. return variables.ConstantVariable.create(value, source=source)
  360. def is_proxy(self) -> bool:
  361. try:
  362. self.as_proxy()
  363. return True
  364. except NotImplementedError:
  365. return False
  366. def as_proxy(self) -> Any:
  367. raise NotImplementedError(str(self))
  368. def maybe_fx_node(self) -> Optional[Node]:
  369. try:
  370. proxy = self.as_proxy()
  371. import torch.fx
  372. if isinstance(proxy, torch.fx.Proxy):
  373. return proxy.node
  374. return None
  375. except NotImplementedError:
  376. return None
  377. def _contains_self_reference(self) -> bool:
  378. """Check if this variable references itself (directly or indirectly)."""
  379. found_self = False
  380. def check(vt: "VariableTracker") -> None:
  381. nonlocal found_self
  382. if vt is self:
  383. found_self = True
  384. # unwrap first iteration - otherwise we can't detect if we revisit self
  385. for key, subvalue in self.__dict__.items():
  386. if key not in self._nonvar_fields:
  387. VariableTracker.visit(check, subvalue)
  388. return found_self
  389. def reconstruct(self, codegen: "PyCodegen") -> None:
  390. raise NotImplementedError
  391. def unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]:
  392. raise NotImplementedError
  393. def force_unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]:
  394. # like unpack_var_sequence, but should only be used when it is
  395. # safe to eagerly (vs. lazily) unpack this variable.
  396. # e.g. map(f, x) is normally evaluated lazily but sometimes
  397. # we want to force eager unpacking, e.g. when converting to a list.
  398. # NOTE: this method is allowed to mutate the VariableTracker, so
  399. # it should only be called once.
  400. return self.unpack_var_sequence(tx)
  401. def has_unpack_var_sequence(self, tx: Any) -> bool:
  402. try:
  403. self.unpack_var_sequence(tx)
  404. return True
  405. except NotImplementedError:
  406. return False
  407. # NB: don't call force_unpack_var_sequence, especially if it mutates!
  408. def has_force_unpack_var_sequence(self, tx: Any) -> bool:
  409. return self.has_unpack_var_sequence(tx)
  410. # Forces unpacking the var sequence while also applying a function to each element.
  411. # Only use when it is safe to eagerly unpack this variable (like force_unpack_var_sequence).
  412. # INVARIANT: variable must satisfy has_force_unpack_var_sequence() == True!
  413. def force_apply_to_var_sequence(
  414. self, tx: Any, fn: Callable[["VariableTracker"], Any]
  415. ) -> None:
  416. assert self.has_force_unpack_var_sequence(tx)
  417. for v in self.unpack_var_sequence(tx):
  418. fn(v)
  419. def call_obj_hasattr(
  420. self, tx: "InstructionTranslator", name: str
  421. ) -> "ConstantVariable":
  422. unimplemented(
  423. gb_type="Unsupported hasattr call",
  424. context=f"call_obj_hasattr {self} {name}",
  425. explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`",
  426. hints=[
  427. f"Avoid calling `hasattr({self.__class__.__name__}, {name})` in your code.",
  428. *graph_break_hints.SUPPORTABLE,
  429. ],
  430. )
  431. def call_function(
  432. self,
  433. tx: Any,
  434. args: Sequence["VariableTracker"],
  435. kwargs: dict[str, "VariableTracker"],
  436. ) -> "VariableTracker":
  437. unimplemented(
  438. gb_type="Unsupported function call",
  439. context=f"call_function {self} {args} {kwargs}",
  440. explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`",
  441. hints=[
  442. f"Avoid calling `{self.debug_repr()}` in your code.",
  443. "Please report an issue to PyTorch.",
  444. ],
  445. )
  446. def call_method(
  447. self,
  448. tx: Any,
  449. name: str,
  450. args: list["VariableTracker"],
  451. kwargs: dict[str, "VariableTracker"],
  452. ) -> "VariableTracker":
  453. if name == "__len__" and self.has_unpack_var_sequence(tx):
  454. assert not (args or kwargs)
  455. return variables.ConstantVariable.create(len(self.unpack_var_sequence(tx)))
  456. elif (
  457. name == "__getattr__"
  458. and len(args) == 1
  459. and args[0].is_python_constant()
  460. and not kwargs
  461. ):
  462. return self.var_getattr(tx, args[0].as_python_constant())
  463. elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
  464. other = args[0]
  465. if not isinstance(self, type(other)) and not (
  466. isinstance(self, variables.GetAttrVariable)
  467. or isinstance(other, variables.GetAttrVariable)
  468. ):
  469. # NB: GetAttrVariable is a special case because sometimes an
  470. # object can map to GetAttrVariable but other time as
  471. # SkipFunctionVariable if it is an input to the compiled
  472. # function, e.g. tensor.data_ptr
  473. return variables.ConstantVariable.create(NotImplemented)
  474. # NB : Checking for mutation is necessary because we compare
  475. # constant values
  476. if (
  477. not self.is_python_constant()
  478. or not other.is_python_constant()
  479. or tx.output.side_effects.has_pending_mutation(self)
  480. or tx.output.side_effects.has_pending_mutation(other)
  481. ):
  482. unimplemented(
  483. gb_type="Builtin `operator.*` comparison with constant `self` failed",
  484. context=f"call_method {self} {name} {args} {kwargs}",
  485. explanation=f"Failed to compare {self} with {other}, "
  486. + f"because {other} is not a Python constant or its mutation check fails.",
  487. hints=[],
  488. )
  489. try:
  490. return variables.ConstantVariable.create(
  491. cmp_name_to_op_mapping[name](
  492. self.as_python_constant(), other.as_python_constant()
  493. )
  494. )
  495. except Exception as e:
  496. raise_observed_exception(
  497. type(e),
  498. tx,
  499. args=[list(map(variables.ConstantVariable.create, e.args))],
  500. )
  501. hints = [
  502. f"Avoid calling `{self.python_type_name()}.{name}` in your code.",
  503. "Please report an issue to PyTorch.",
  504. ]
  505. # additional hint for method calls on improperly constructed iterators
  506. if isinstance(self, variables.UserDefinedObjectVariable) and name in (
  507. "__iter__",
  508. "__next__",
  509. ):
  510. if isinstance(self.value, (KeysView, ItemsView, ValuesView)):
  511. hints.append(
  512. "Consider moving the creation of dict view object (e.g. `dict.keys()`, `dict.items()`,) "
  513. "to the compiled region, instead of passing it as an input to the compiled region."
  514. )
  515. hints.append(
  516. "Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) "
  517. "passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). "
  518. "This can happen unintentionally if a previous graph break happens with a builtin iterator "
  519. "in the local scope."
  520. )
  521. hints.append(
  522. "List/dict comprehensions in Python <= 3.11 result in implicit function calls, which Dynamo "
  523. "cannot trace as a top level frame. Possible workarounds are (1) use a loop instead of a comprehension, "
  524. "(2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a "
  525. "function, or (4) use Python 3.12+."
  526. )
  527. unimplemented(
  528. gb_type="Unsupported method call",
  529. context=f"call_method {self} {name} {args} {kwargs}",
  530. explanation=f"Dynamo does not know how to trace method `{name}` of class `{self.python_type_name()}`",
  531. hints=hints,
  532. )
  533. def call_tree_map(
  534. self,
  535. tx: Any,
  536. tree_map_fn: "UserFunctionVariable",
  537. map_fn: "VariableTracker",
  538. rest: Sequence["VariableTracker"],
  539. tree_map_kwargs: dict[str, "VariableTracker"],
  540. ) -> "VariableTracker":
  541. """Performance optimization to implement optree.tree_map faster than tracing it"""
  542. is_leaf_var = tree_map_kwargs.get("is_leaf")
  543. if is_leaf_var is not None and not is_leaf_var.is_constant_none():
  544. pred_result = is_leaf_var.call_function(tx, [self], {})
  545. try:
  546. leaf_decision = pred_result.as_python_constant()
  547. except NotImplementedError:
  548. return self._tree_map_fallback(
  549. tx,
  550. tree_map_fn,
  551. map_fn,
  552. rest,
  553. tree_map_kwargs,
  554. )
  555. if leaf_decision:
  556. return map_fn.call_function(tx, [self, *rest], {})
  557. return self.call_tree_map_branch(
  558. tx,
  559. tree_map_fn,
  560. map_fn,
  561. rest,
  562. tree_map_kwargs,
  563. )
  564. def call_tree_map_branch(
  565. self,
  566. tx: Any,
  567. tree_map_fn: "UserFunctionVariable",
  568. map_fn: "VariableTracker",
  569. rest: Sequence["VariableTracker"],
  570. tree_map_kwargs: dict[str, "VariableTracker"],
  571. ) -> "VariableTracker":
  572. """Emulate optree.tree_map without is_leaf/none_is_leaf checks (handled above)"""
  573. return self._tree_map_fallback(
  574. tx,
  575. tree_map_fn,
  576. map_fn,
  577. rest,
  578. tree_map_kwargs,
  579. )
  580. def _tree_map_fallback(
  581. self,
  582. tx: Any,
  583. tree_map_fn: "UserFunctionVariable",
  584. map_fn: "VariableTracker",
  585. rest: Sequence["VariableTracker"],
  586. tree_map_kwargs: dict[str, "VariableTracker"],
  587. ) -> "VariableTracker":
  588. tree_map_fn_copy = tree_map_fn.clone()
  589. tree_map_fn_copy._maybe_call_tree_map_fastpath = lambda *args, **kwargs: None # type: ignore[missing-attribute]
  590. log.debug(
  591. "tree_map fastpath fallback triggered for %s (rest=%s, kwargs=%s)",
  592. self,
  593. rest,
  594. tree_map_kwargs,
  595. )
  596. return tree_map_fn_copy.call_function(
  597. tx,
  598. [map_fn, self, *rest],
  599. tree_map_kwargs,
  600. )
  601. def set_name_hint(self, name: str) -> None:
  602. pass
  603. def realize(self) -> "VariableTracker":
  604. """Used by LazyVariableTracker to build the real VariableTracker"""
  605. return self
  606. def unwrap(self) -> "VariableTracker":
  607. """Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
  608. return self
  609. def is_realized(self) -> bool:
  610. """Used by LazyVariableTracker to indicate an unrealized node"""
  611. return True
  612. def next_variable(self, tx: Any) -> "VariableTracker":
  613. unimplemented(
  614. gb_type="Unsupported next() call",
  615. context=f"next({self})",
  616. explanation=f"Dynamo does not know how to trace calling `next()` on variable `{self}`.",
  617. hints=[*graph_break_hints.USER_ERROR],
  618. )
  619. def is_strict_mode(self, tx: Any) -> bool:
  620. return bool(tx.strict_checks_fn and tx.strict_checks_fn(self))
  621. def is_mutable(self) -> bool:
  622. """Whether Dynamo allows mutation on this variable."""
  623. return not self.is_immutable()
  624. def is_immutable(self) -> bool:
  625. """Whether Dynamo bans mutation on this variable."""
  626. return self.mutation_type is None
  627. @staticmethod
  628. def build(
  629. tx: Any,
  630. value: Any,
  631. source: Optional[Source] = None,
  632. realize: bool = False,
  633. ) -> Any:
  634. """Create a new VariableTracker from a value and optional Source"""
  635. if source is None:
  636. return builder.SourcelessBuilder.create(tx, value)
  637. elif realize:
  638. return builder.VariableBuilder(tx, source)(value)
  639. elif type(value) in variables.LazyConstantVariable.supported_types:
  640. # Use LazyConstantVariable for primitives to enable deferred
  641. # guard installation - constants that are just passed through
  642. # won't cause recompilation when their values change.
  643. return variables.LazyConstantVariable.create(value, source)
  644. else:
  645. return variables.LazyVariableTracker.create(value, source)
  646. def is_python_hashable(self) -> bool:
  647. """
  648. Unlike the variable tracker's own __hash__, this method checks whether
  649. the underlying Python object referenced by this variable tracker is hashable.
  650. """
  651. try:
  652. type_self = self.python_type()
  653. except NotImplementedError:
  654. type_self = type(self)
  655. unimplemented(
  656. gb_type="Dynamo cannot determine whether the underlying object is hashable",
  657. context=f"is_python_hashable {self}",
  658. explanation=f"Dynamo does not know whether the underlying python object for {self} is hashable",
  659. hints=[
  660. (
  661. f"Consider using a different type of object as the dictionary key instead of {type_self}."
  662. ),
  663. *graph_break_hints.SUPPORTABLE,
  664. ],
  665. )
  def get_python_hash(self) -> int:
      """
      Return the hash of the underlying Python object, not of this tracker.

      Unlike the variable tracker's own ``__hash__``, this method is used by
      ConstDictVariableTracker to compute the hash of the underlying key
      object. The base implementation cannot compute it and reports a graph
      break through ``unimplemented``.
      """
      # python_type() can raise NotImplementedError. Fall back to the tracker
      # class so that building the hint text can never turn the intended
      # graph break into a stray NotImplementedError.
      try:
          type_self = self.python_type()
      except NotImplementedError:
          type_self = type(self)
      unimplemented(
          gb_type="Dynamo cannot determine the hash of an object",
          context=f"get_python_hash {self}",
          explanation=f"Dynamo does not know the hash of the underlying python object for {self}",
          hints=[
              (
                  f"Consider using a different type of object as the dictionary key instead of {type_self}."
              ),
              *graph_break_hints.SUPPORTABLE,
          ],
      )
  def is_python_equal(self, other: object) -> bool:
      """
      Compare the underlying Python objects of ``self`` and ``other`` for equality.

      NB - Deliberately not overriding the ``__eq__`` method: a class that
      defines ``__eq__`` without ``__hash__`` gets its ``__hash__`` disabled,
      which would make the tracker itself unhashable. The base implementation
      cannot decide equality and reports a graph break through
      ``unimplemented``.
      """
      # python_type() can raise NotImplementedError. Fall back to the tracker
      # class so that building the hint text can never turn the intended
      # graph break into a stray NotImplementedError.
      try:
          type_self = self.python_type()
      except NotImplementedError:
          type_self = type(self)
      unimplemented(
          gb_type="Dynamo cannot determine the equality comparison of an object",
          context=f"is_python_equal {self}",
          explanation=f"Dynamo does not know the equality comparison of the underlying python object for {self}",
          hints=[
              (
                  f"Consider using a different type of object as the dictionary key instead of {type_self}."
              ),
              *graph_break_hints.SUPPORTABLE,
          ],
      )
  def __init__(
      self,
      *,
      source: Optional[Source] = None,
      mutation_type: Optional[MutationType] = None,
  ) -> None:
      """
      Store the tracker's source and mutation type, sanity-checking the pair.

      Args:
          source: where the tracked value comes from, or ``None`` for a
              sourceless value.
          mutation_type: how the value may be mutated. ``ValueMutationNew`` /
              ``AttributeMutationNew`` require ``source`` to be ``None``.
              ``ValueMutationExisting`` / ``AttributeMutationExisting``
              require a ``source``.
      """
      super().__init__()
      self.source = source
      self.mutation_type = mutation_type
      # NOTE: some code assigns mutation_type after construction for
      # convenience; such late assignments are not validated here.
      if mutation_type is None:
          return
      is_new_value = isinstance(mutation_type, (ValueMutationNew, AttributeMutationNew))
      if is_new_value:
          # Failing here means a source was passed by mistake, or
          # mutation_type is wrong.
          assert source is None
          return
      assert isinstance(
          mutation_type, (ValueMutationExisting, AttributeMutationExisting)
      )
      # Failing here means the source was forgotten, or mutation_type is
      # wrong.
      assert source is not None
  def __init_subclass__(cls, **kwargs: Any) -> None:
      """
      Wrap each subclass's ``as_python_constant`` and ``reconstruct`` so that
      neither can be re-entered for the same object within one call chain,
      i.e. on self-referential objects.

      For ``as_python_constant``: self-referential objects are NOT treated as
      constants; the re-entrant call raises
      ``AsPythonConstantNotImplementedError``.
      For ``reconstruct``: Dynamo graph-breaks. The graph break can be avoided
      if the VT subclass generates and caches itself before recursively
      ``reconstruct``-ing its members - see ListVariable for an example.
      """
      super().__init_subclass__(**kwargs)

      def as_python_constant_failure(self) -> NoReturn:
          raise AsPythonConstantNotImplementedError(
              self, msg=f"{self} is self-referential"
          )

      VariableTracker._add_call_once_guard(
          cls, "as_python_constant", as_python_constant_failure
      )

      def reconstruct_failure(self) -> NoReturn:
          unimplemented(
              gb_type="Reconstruction failure (self-referential)",
              context=str(self),
              explanation=f"Dynamo tried to reconstruct sourceless variable {self}, but it is self-referential. "
              "Dynamo must manually implement reconstruction rules for self-referential sourceless variables.",
              hints=[
                  "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable "
                  "that Dynamo cannot reconstruct, then remove it from the return statement.",
                  "Remove the self-reference in the variable. A self-referring list, for example, is `l = []; l.append(l)`.",
                  *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
                  "Report an issue to PyTorch if you need self-referential reconstruction support.",
              ],
          )

      VariableTracker._add_call_once_guard(cls, "reconstruct", reconstruct_failure)
  @staticmethod
  def _add_call_once_guard(
      cls: type["VariableTracker"],
      method: str,
      callback: Callable[["VariableTracker"], Any],
  ) -> None:
      """
      Replace ``cls.<method>`` with a wrapper that calls ``callback`` when the
      method is re-entered for the same object.

      Active calls are tracked as ``(id(self), method)`` keys in the set held
      by the ``_vt_active_calls`` context variable. Nothing is wrapped when the
      attribute still resolves to the ``VariableTracker`` implementation, or
      when it has already been guarded.

      Args:
          cls: class whose method is wrapped in place (via ``setattr``).
          method: name of the method to guard.
          callback: invoked with the instance on a re-entrant call; it is
              expected to raise. If it returns, the method still runs, but the
              outer frame keeps ownership of the active-call entry.
      """
      original_method = getattr(cls, method)
      if original_method is getattr(VariableTracker, method) or hasattr(
          original_method, "_call_once_guarded"
      ):
          return

      # Return type is Any: the wrapper returns whatever the wrapped method
      # returns (e.g. a Python constant for as_python_constant), not
      # necessarily a VariableTracker.
      @functools.wraps(original_method)
      def guarded_method(self: "VariableTracker", *args: Any, **kwargs: Any) -> Any:
          active = _vt_active_calls.get()
          if active is None:
              # First guarded call in this context: create the tracking set.
              active = set()
              _vt_active_calls.set(active)
          key = (id(self), method)
          if key in active:
              # Re-entrant call. The callback normally raises; if it returns,
              # run the method WITHOUT taking ownership of `key`. Otherwise
              # this frame's cleanup would discard the outer frame's entry and
              # disarm the guard for the rest of the outer call.
              callback(self)
              return original_method(self, *args, **kwargs)
          active.add(key)
          try:
              return original_method(self, *args, **kwargs)
          finally:
              active.discard(key)

      guarded_method._call_once_guarded = True  # pyrefly: ignore[missing-attribute]
      setattr(cls, method, guarded_method)
  def raise_type_error_exc(tx: Any, msg_str: str) -> NoReturn:
      """
      Raise an observed ``TypeError`` whose single argument is ``msg_str``.

      The message is wrapped in a ``ConstantVariable`` and handed, together
      with ``tx``, to ``raise_observed_exception``.
      """
      raise_observed_exception(
          TypeError, tx, args=[variables.ConstantVariable.create(msg_str)]
      )
  def typestr(*objs: object) -> str:
      """
      Return a short, space-separated description of ``objs``.

      A ``VariableTracker`` is described by ``str()``; any other object is
      described by its type's ``__name__``.
      """
      if len(objs) != 1:
          return " ".join(typestr(item) for item in objs)
      (only,) = objs
      if isinstance(only, VariableTracker):
          return str(only)
      return type(only).__name__
  # Module-level alias for the unbound ``type.__instancecheck__``.
  # Calling instancecheck(cls, obj) runs type's default isinstance check for
  # ``cls`` directly, i.e. without going through an ``__instancecheck__``
  # override defined on cls's metaclass.
  # NOTE(review): the use sites are not visible in this chunk — confirm that
  # bypassing metaclass overrides is the intended purpose.
  793. instancecheck = type.__instancecheck__
  794. from . import builder
  795. from .lazy import LazyVariableTracker