| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930 |
- """
- Core variable tracking functionality for Dynamo. This module defines the fundamental
- classes and systems used to track and manage variables during Dynamo's operation.
- The module provides:
- 1. VariableTracker - The base class for tracking variables during compilation
- 2. MutationType system - Classes for tracking and managing mutations to variables
- 3. Source type management - Utilities for tracking variable origins and scope
- 4. Variable state management - Tools for managing variable state and transformations
- These components form the foundation of Dynamo's variable handling system,
- enabling accurate tracking and transformation of Python code into optimized
- computations.
- """
- import collections
- import functools
- import logging
- from collections.abc import Callable, ItemsView, KeysView, Sequence, ValuesView
- from contextvars import ContextVar
- from enum import Enum
- from typing import Any, NoReturn, Optional, TYPE_CHECKING
- from torch._guards import Guard
- from torch.fx.proxy import Node
- from .. import graph_break_hints, variables
- from ..current_scope_id import current_scope_id
- from ..exc import raise_observed_exception, unimplemented
- from ..guards import GuardBuilder, install_guard
- from ..source import AttrSource, Source
- from ..utils import cmp_name_to_op_mapping, istype
- if TYPE_CHECKING:
- from ..codegen import PyCodegen
- from ..symbolic_convert import InstructionTranslator
- from .constant import ConstantVariable
- from .functions import UserFunctionVariable
- log = logging.getLogger(__name__)
- # Tracks active method calls on VariableTracker instances to detect self-referential
- # calls (e.g., as_python_constant on a list that contains itself). Maps
- # (id(instance), method_name) tuples to track which calls are in progress.
- _vt_active_calls: ContextVar[set[tuple[int, str]] | None] = ContextVar(
- "_vt_active_calls", default=None
- )
class SourceType(Enum):
    """
    Splits the values a VariableTracker can represent into two cases:

    - Existing: the value already existed before tracing, and Dynamo started
      tracking it during introspection.
    - New: the value is created during Dynamo introspection.

    Invariants:
    1. A `VariableTracker` associated with `Existing` must have a non-None
       `source`.
    2. A `VariableTracker` associated with `New` usually has a `source` of
       None. The exception is side-effect codegen for `AttributeMutationNew`,
       during which a `LocalSource('tmp...')` is generated for the variable to
       facilitate codegen.
    """

    Existing = 0
    New = 1
class MutationType:
    """
    Base class for ``VariableTracker.mutation_type`` markers. A marker encodes

    1. which kind of mutation Dynamo allows on the variable, and
    2. whether the represented value already existed before Dynamo tracing
       (``SourceType.Existing``) or was created during it (``SourceType.New``).
    """

    def __init__(self, typ: SourceType) -> None:
        # ``self.scope`` records where the value was created, so that
        # HigherOrderOperator tracing can tell markers created inside the HOP
        # apart from markers created outside it. For example, mutating `a`
        # below is unsafe because `a` was constructed in a different scope:
        #
        # def f(x):
        #     a = 1
        #     def g(x):
        #         nonlocal a
        #         a = 2
        #         return x
        #     return wrap(g, x) + a
        #
        # scope == 0: the object was an existing variable.
        # scope == 1: the object was created while Dynamo was introspecting a
        #             function (and no HigherOrderOps were involved).
        # scope >= 2: the object was created while introspecting a
        #             HigherOrderOp; the number is the HOP nesting level.
        if typ is SourceType.Existing:
            self.scope = 0
            return
        if typ is SourceType.New:
            self.scope = current_scope_id()
            return
        unimplemented(
            gb_type="Unsupported SourceType",
            context=f"MutationType.__init__ {self} {typ}",
            explanation=f"Dynamo does not support the type `{typ}`",
            hints=[
                "This branch is not supposed to be reachable.",
                *graph_break_hints.DYNAMO_BUG,
            ],
        )
class ValueMutationNew(MutationType):
    """
    ``mutation_type`` marker for a value that

    1. Dynamo allows to be mutated itself (not only its attributes), and
    2. was created by the bytecode Dynamo is tracing through.

    A list constructed inside the traced code is a typical example: its
    mutations have to be modeled, but no bytecode has to be emitted for them
    unless the list escapes into the Python world.

    Markers of this type compare equal only to themselves and hash by identity.
    """

    def __init__(self) -> None:
        super().__init__(SourceType.New)

    def __hash__(self) -> int:
        return id(self)

    def __eq__(self, other: object) -> bool:
        return other is self
class ValueMutationExisting(MutationType):
    """
    ``mutation_type`` marker for a value that

    1. Dynamo allows to be mutated itself (not only its attributes), and
    2. already existed before Dynamo tracing started.

    A pre-existing list is a typical example: any mutation Dynamo sees must be
    buffered and re-applied after the graph runs, because Python code may use
    the list afterwards.
    """

    # Whether a mutation happened on the associated `VariableTracker`; lets
    # SideEffects quickly filter which pre-existing values need mutation
    # codegen.
    is_modified: bool

    def __init__(self, is_modified: bool = False) -> None:
        super().__init__(SourceType.Existing)
        self.is_modified = is_modified
class AttributeMutation(MutationType):
    """
    Common base of the ``mutation_type`` markers under which Dynamo allows
    mutating the attributes of a value.
    """
class AttributeMutationExisting(AttributeMutation):
    """
    ``mutation_type`` marker for a value that

    1. Dynamo allows attribute mutation on, and
    2. already existed before Dynamo tracing started.

    A pre-existing object is a typical example: mutations to it are buffered
    and re-applied after the graph runs, because Python code may use the object
    afterwards.
    """

    def __init__(self) -> None:
        super().__init__(SourceType.Existing)
class AttributeMutationNew(AttributeMutation):
    """
    ``mutation_type`` marker for a value that

    1. Dynamo allows attribute mutation on, and
    2. was created by the bytecode Dynamo is tracing through.

    A newly created object is a typical example: its mutations are modeled, but
    no bytecode has to be emitted for them unless the object escapes into the
    Python world.

    ``cls_source`` optionally records a Source for the object's class.
    """

    def __init__(self, cls_source: Optional[Source] = None) -> None:
        super().__init__(SourceType.New)
        self.cls_source = cls_source
- def _is_top_level_scope(scope_id: int) -> bool:
- return scope_id == 1
def is_side_effect_safe(m: MutationType) -> bool:
    """Return whether a side effect on a value carrying marker ``m`` is allowed now.

    In the top-level scope (no HigherOrderOperators involved) both values
    created in this scope and pre-existing values may be mutated. In any other
    scope only values created in that exact scope may be mutated.
    """
    scope_id = current_scope_id()
    return _is_top_level_scope(scope_id) or m.scope == scope_id
- # This helps users of `as_python_constant` to catch unimplemented error with
- # more information; it inherits `NotImplementedError` for backward
- # compatibility reasons.
- class AsPythonConstantNotImplementedError(NotImplementedError):
- vt: "VariableTracker"
- def __init__(self, vt: "VariableTracker", msg: str | None = None) -> None:
- msg = f"{vt} is not a constant" if msg is None else msg
- super().__init__(msg)
- self.vt = vt
class VariableTrackerMeta(type):
    """Metaclass of VariableTracker and its subclasses.

    ``__new__`` chooses the concrete metaclass for each new class:

    - ``VariableTracker`` itself, and any class whose namespace or one of whose
      bases (looked up with ``getattr``, so inherited values count) has a truthy
      ``_no_implicit_realize``, gets plain ``VariableTrackerMeta`` with the
      standard ``isinstance`` behavior (needed to avoid infinite recursion).
    - Every other class gets ``ImplicitRealizingVariableTrackerMeta``.

    ``__init__`` appends every class it initializes to ``all_subclasses``.
    """

    all_subclasses: list[type] = []

    def __new__(
        mcs: type, name: str, bases: tuple[type, ...], attrs: dict[str, Any]
    ) -> type:
        opted_out = attrs.get("_no_implicit_realize", False) or any(
            getattr(base, "_no_implicit_realize", False) for base in bases
        )
        chosen_meta = (
            VariableTrackerMeta
            if opted_out or name == "VariableTracker"
            else ImplicitRealizingVariableTrackerMeta
        )
        return super().__new__(chosen_meta, name, bases, attrs)

    def __init__(
        cls: type, name: str, bases: tuple[type, ...], attrs: dict[str, Any]
    ) -> None:
        super().__init__(name, bases, attrs)  # type: ignore[misc]
        # Register the freshly built class.
        VariableTrackerMeta.all_subclasses.append(cls)
class ImplicitRealizingVariableTrackerMeta(VariableTrackerMeta):
    """Metaclass whose ``isinstance`` checks also handle LazyVariableTracker."""

    def __instancecheck__(self, instance: object) -> bool:
        """Make isinstance work with LazyVariableTracker"""
        # ``instancecheck`` is the raw type.__instancecheck__, so these calls
        # do not recurse back into this override.
        if not instancecheck(LazyVariableTracker, instance):
            return instancecheck(self, instance)
        # Lazy wrappers answer for themselves (see lazy_isinstance in .lazy).
        return instance.lazy_isinstance(self)  # pyrefly: ignore[missing-attribute]
class VariableTracker(metaclass=VariableTrackerMeta):
    """
    Base class for tracked locals and stack values.

    VariableTracker instances are immutable and should be copied in
    order to change them (see ``clone``).

    Prefer the factory function VariableTracker.build() over VariableTracker.__init__().
    """

    # Fields to leave unmodified in apply(). visit() and
    # _contains_self_reference() also skip these when walking an instance's
    # attributes.
    _nonvar_fields = {
        "value",
        "guards",
        "source",
        "mutation_type",
        "parents_tracker",
        "user_code_variable_name",
    }

    def clone(self, **kwargs: Any) -> "VariableTracker":
        """Shallow copy with some (optional) changes.

        Rebuilds the object by calling the concrete class's constructor with
        the instance ``__dict__`` overlaid by ``kwargs``; every stored
        attribute must therefore be accepted as a keyword by ``__init__``.
        """
        args = dict(self.__dict__)
        args.update(kwargs)
        return self.__class__(**args)

    @classmethod
    def visit(
        cls,
        fn: Callable[["VariableTracker"], None],
        value: Any,
        cache: Optional[dict[int, Any]] = None,
    ) -> None:
        """
        Walk value and call fn on all the VariableTracker instances.

        Recurses into VariableTracker attributes (except ``_nonvar_fields``),
        into the items of lists/tuples, and into the values of
        dicts/OrderedDicts (container types are matched with ``istype``).
        Each object is entered at most once per ``cache``, so cyclic
        structures terminate.
        """
        if cache is None:
            cache = {}

        idx = id(value)
        if idx in cache:
            return
        # save `value` to keep it alive and ensure id() isn't reused
        cache[idx] = value

        if isinstance(value, VariableTracker):
            value = value.unwrap()
            fn(value)
            value = value.unwrap()  # calling fn() might have realized it
            nonvars = value._nonvar_fields
            for key, subvalue in value.__dict__.items():
                if key not in nonvars:
                    cls.visit(fn, subvalue, cache)
        elif istype(value, (list, tuple)):
            for subvalue in value:
                cls.visit(fn, subvalue, cache)
        elif istype(value, (dict, collections.OrderedDict)):
            for subvalue in value.values():
                cls.visit(fn, subvalue, cache)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def debug_repr(self) -> str:
        """repr() of the constant value when this is a constant, else repr(self).

        Intended to be overridden to provide more info.
        """
        try:
            return repr(self.as_python_constant())
        except NotImplementedError:
            return repr(self)

    def python_type(self) -> type:
        """
        Return the Python type of the value tracked by this variable.

        The default implementation returns ``type(self.as_python_constant())``;
        subclasses that do not model constants should override it. If the type
        cannot be determined or is not relevant, leaving it undefined or
        invoking super() is always sound.

        Example:
            class SetVariable(VariableTracker):
                def python_type(self):
                    return set

        Raises:
            NotImplementedError: if the variable is not a Python constant and
                the subclass does not override this method.
        """
        try:
            return type(self.as_python_constant())
        except NotImplementedError:
            raise NotImplementedError(f"{self} has no type") from None

    def python_type_name(self) -> str:
        """Name of ``python_type()``, or ``"<unknown type>"`` if it is unknown."""
        try:
            return self.python_type().__name__
        except NotImplementedError:
            return "<unknown type>"

    def _python_type_for_message(self) -> type:
        """``python_type()`` for use in graph-break messages; never raises.

        Falls back to the tracker's own class when the Python type is unknown,
        so building a message cannot itself raise NotImplementedError.
        """
        try:
            return self.python_type()
        except NotImplementedError:
            return type(self)

    def as_python_constant(self) -> Any:
        """Return the Python constant this variable represents.

        Raises:
            AsPythonConstantNotImplementedError: if it is not a constant (always,
                in this base implementation).
        """
        raise AsPythonConstantNotImplementedError(self)

    def guard_as_python_constant(self) -> Any:
        """Like as_python_constant(), but graph-breaks instead of raising.

        NOTE: the name suggests guards (historically described as adding
        ID_MATCH guards), but this base implementation installs none itself;
        subclasses may override it to do so.
        """
        try:
            return self.as_python_constant()
        except NotImplementedError:
            unimplemented(
                gb_type="Not a Python constant",
                context=f"guard_as_python_constant {self}",
                explanation=f"Failed to convert {self} into a Python constant.",
                hints=[],
            )

    def is_python_constant(self) -> bool:
        """Whether as_python_constant() succeeds for this variable."""
        try:
            self.as_python_constant()
            return True
        except NotImplementedError:
            return False

    def is_constant_match(self, *values: Any) -> bool:
        """
        Check if this variable is a python constant matching one of the given values.

        Examples:
            var.is_constant_match(None)  # True if var is constant None
            var.is_constant_match(True, False)  # True if var is constant True or False
            var.is_constant_match(NotImplemented)  # True if var is constant NotImplemented

        The base implementation always returns False.
        """
        return False

    def is_constant_none(self) -> bool:
        """Check if this variable is a constant None value (always False here)."""
        return False

    def make_guard(self, fn: Callable[..., Any]) -> Guard:
        """Build a guard on this variable's source; raise NotImplementedError without one."""
        if self.source:
            return self.source.make_guard(fn)
        raise NotImplementedError

    # TODO[@lucaskabela] - change this type to `InstructionTranslatorBase`
    # and cascade that (large blast radius)
    def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
        """getattr(self, name) returning a python constant"""
        raise NotImplementedError

    def is_symnode_like(self) -> bool:
        """Return True for values that can participate in SymNode operations"""
        return False

    def is_tensor(self) -> bool:
        """Return True for TensorVariable instances"""
        return False

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """getattr(self, name) returning a new variable.

        The base implementation only supports attributes whose value, as given
        by const_getattr(), is a literal; anything else raises
        NotImplementedError.
        """
        value = self.const_getattr(tx, name)
        if not variables.ConstantVariable.is_literal(value):
            raise NotImplementedError
        source = self.source and AttrSource(self.source, name)
        if source and not self.is_python_constant():
            # The second condition is to avoid guards on const getattr objects
            # like __code__.co_argcount
            install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
        return variables.ConstantVariable.create(value, source=source)

    def is_proxy(self) -> bool:
        """Whether as_proxy() succeeds for this variable."""
        try:
            self.as_proxy()
            return True
        except NotImplementedError:
            return False

    def as_proxy(self) -> Any:
        raise NotImplementedError(str(self))

    def maybe_fx_node(self) -> Optional[Node]:
        """The FX node behind as_proxy() if it is a torch.fx.Proxy, else None."""
        try:
            proxy = self.as_proxy()
            import torch.fx

            if isinstance(proxy, torch.fx.Proxy):
                return proxy.node
            return None
        except NotImplementedError:
            return None

    def _contains_self_reference(self) -> bool:
        """Check if this variable references itself (directly or indirectly)."""
        found_self = False

        def check(vt: "VariableTracker") -> None:
            nonlocal found_self
            if vt is self:
                found_self = True

        # Walk our own fields by hand instead of calling visit(self): that would
        # call check() on self itself (always a hit) and put self in the cache,
        # so a path leading back to self could never be detected.
        # A single cache is shared across fields so substructure reachable
        # from several fields is walked only once; found_self is cumulative,
        # so the result is the same as walking each field separately.
        cache: dict[int, Any] = {}
        for key, subvalue in self.__dict__.items():
            if key not in self._nonvar_fields:
                VariableTracker.visit(check, subvalue, cache)
        return found_self

    def reconstruct(self, codegen: "PyCodegen") -> None:
        raise NotImplementedError

    def unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]:
        raise NotImplementedError

    def force_unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]:
        # like unpack_var_sequence, but should only be used when it is
        # safe to eagerly (vs. lazily) unpack this variable.
        # e.g. map(f, x) is normally evaluated lazily but sometimes
        # we want to force eager unpacking, e.g. when converting to a list.
        # NOTE: this method is allowed to mutate the VariableTracker, so
        # it should only be called once.
        return self.unpack_var_sequence(tx)

    def has_unpack_var_sequence(self, tx: Any) -> bool:
        """Whether unpack_var_sequence() succeeds (it is actually called)."""
        try:
            self.unpack_var_sequence(tx)
            return True
        except NotImplementedError:
            return False

    # NB: don't call force_unpack_var_sequence, especially if it mutates!
    def has_force_unpack_var_sequence(self, tx: Any) -> bool:
        return self.has_unpack_var_sequence(tx)

    # Forces unpacking the var sequence while also applying a function to each element.
    # Only use when it is safe to eagerly unpack this variable (like force_unpack_var_sequence).
    # INVARIANT: variable must satisfy has_force_unpack_var_sequence() == True!
    def force_apply_to_var_sequence(
        self, tx: Any, fn: Callable[["VariableTracker"], Any]
    ) -> None:
        assert self.has_force_unpack_var_sequence(tx)
        # NOTE(review): iterates unpack_var_sequence(), not
        # force_unpack_var_sequence(); subclasses where the two differ
        # presumably override this method - confirm before relying on it.
        for v in self.unpack_var_sequence(tx):
            fn(v)

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> "ConstantVariable":
        """Trace ``hasattr(self, name)``; the base implementation graph-breaks."""
        unimplemented(
            gb_type="Unsupported hasattr call",
            context=f"call_obj_hasattr {self} {name}",
            explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`",
            hints=[
                f"Avoid calling `hasattr({self.__class__.__name__}, {name})` in your code.",
                *graph_break_hints.SUPPORTABLE,
            ],
        )

    def call_function(
        self,
        tx: Any,
        args: Sequence["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Trace ``self(*args, **kwargs)``; the base implementation graph-breaks."""
        unimplemented(
            gb_type="Unsupported function call",
            context=f"call_function {self} {args} {kwargs}",
            explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`",
            hints=[
                f"Avoid calling `{self.debug_repr()}` in your code.",
                "Please report an issue to PyTorch.",
            ],
        )

    def call_method(
        self,
        tx: Any,
        name: str,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Trace ``self.<name>(*args, **kwargs)``.

        Generic cases handled here:
        - ``__len__`` on anything that can be unpacked as a sequence;
        - ``__getattr__`` with a constant attribute name (via var_getattr);
        - the comparison methods in ``cmp_name_to_op_mapping`` between two
          constants without pending mutations; ``NotImplemented`` is returned
          when the operand tracker types do not match.
        Anything else graph-breaks with hints tailored to the method.
        """
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not (args or kwargs)
            return variables.ConstantVariable.create(len(self.unpack_var_sequence(tx)))
        elif (
            name == "__getattr__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
        ):
            return self.var_getattr(tx, args[0].as_python_constant())
        elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
            other = args[0]
            if not isinstance(self, type(other)) and not (
                isinstance(self, variables.GetAttrVariable)
                or isinstance(other, variables.GetAttrVariable)
            ):
                # NB: GetAttrVariable is a special case because sometimes an
                # object can map to GetAttrVariable but other time as
                # SkipFunctionVariable if it is an input to the compiled
                # function, e.g. tensor.data_ptr
                return variables.ConstantVariable.create(NotImplemented)
            # NB : Checking for mutation is necessary because we compare
            # constant values
            if (
                not self.is_python_constant()
                or not other.is_python_constant()
                or tx.output.side_effects.has_pending_mutation(self)
                or tx.output.side_effects.has_pending_mutation(other)
            ):
                unimplemented(
                    gb_type="Builtin `operator.*` comparison with constant `self` failed",
                    context=f"call_method {self} {name} {args} {kwargs}",
                    explanation=f"Failed to compare {self} with {other}, "
                    + f"because {other} is not a Python constant or its mutation check fails.",
                    hints=[],
                )
            try:
                return variables.ConstantVariable.create(
                    cmp_name_to_op_mapping[name](
                        self.as_python_constant(), other.as_python_constant()
                    )
                )
            except Exception as e:
                # Re-raise the comparison's exception inside the traced program.
                # ``args`` is the flat list of VariableTracker constructor
                # arguments (one ConstantVariable per entry of ``e.args``), the
                # same shape raise_type_error_exc passes; it previously wrapped
                # that list in another list.
                raise_observed_exception(
                    type(e),
                    tx,
                    args=list(map(variables.ConstantVariable.create, e.args)),
                )

        hints = [
            f"Avoid calling `{self.python_type_name()}.{name}` in your code.",
            "Please report an issue to PyTorch.",
        ]
        # additional hint for method calls on improperly constructed iterators
        if isinstance(self, variables.UserDefinedObjectVariable) and name in (
            "__iter__",
            "__next__",
        ):
            if isinstance(self.value, (KeysView, ItemsView, ValuesView)):
                hints.append(
                    "Consider moving the creation of dict view object (e.g. `dict.keys()`, `dict.items()`,) "
                    "to the compiled region, instead of passing it as an input to the compiled region."
                )
            hints.append(
                "Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) "
                "passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). "
                "This can happen unintentionally if a previous graph break happens with a builtin iterator "
                "in the local scope."
            )
            hints.append(
                "List/dict comprehensions in Python <= 3.11 result in implicit function calls, which Dynamo "
                "cannot trace as a top level frame. Possible workarounds are (1) use a loop instead of a comprehension, "
                "(2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a "
                "function, or (4) use Python 3.12+."
            )
        unimplemented(
            gb_type="Unsupported method call",
            context=f"call_method {self} {name} {args} {kwargs}",
            explanation=f"Dynamo does not know how to trace method `{name}` of class `{self.python_type_name()}`",
            hints=hints,
        )

    def call_tree_map(
        self,
        tx: Any,
        tree_map_fn: "UserFunctionVariable",
        map_fn: "VariableTracker",
        rest: Sequence["VariableTracker"],
        tree_map_kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Performance optimization to implement optree.tree_map faster than tracing it.

        If an ``is_leaf`` kwarg is given and is not constant None, it is called
        on ``self``: a non-constant answer falls back to tracing tree_map_fn,
        a truthy one applies ``map_fn`` to ``(self, *rest)`` directly.
        Otherwise dispatches to call_tree_map_branch().
        """
        is_leaf_var = tree_map_kwargs.get("is_leaf")
        if is_leaf_var is not None and not is_leaf_var.is_constant_none():
            pred_result = is_leaf_var.call_function(tx, [self], {})
            try:
                leaf_decision = pred_result.as_python_constant()
            except NotImplementedError:
                return self._tree_map_fallback(
                    tx,
                    tree_map_fn,
                    map_fn,
                    rest,
                    tree_map_kwargs,
                )
            if leaf_decision:
                return map_fn.call_function(tx, [self, *rest], {})

        return self.call_tree_map_branch(
            tx,
            tree_map_fn,
            map_fn,
            rest,
            tree_map_kwargs,
        )

    def call_tree_map_branch(
        self,
        tx: Any,
        tree_map_fn: "UserFunctionVariable",
        map_fn: "VariableTracker",
        rest: Sequence["VariableTracker"],
        tree_map_kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Emulate optree.tree_map without is_leaf/none_is_leaf checks (handled above).

        The base implementation has no fast path and always falls back to
        tracing tree_map_fn.
        """
        return self._tree_map_fallback(
            tx,
            tree_map_fn,
            map_fn,
            rest,
            tree_map_kwargs,
        )

    def _tree_map_fallback(
        self,
        tx: Any,
        tree_map_fn: "UserFunctionVariable",
        map_fn: "VariableTracker",
        rest: Sequence["VariableTracker"],
        tree_map_kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Trace tree_map_fn itself on ``(map_fn, self, *rest)``.

        Uses a clone whose ``_maybe_call_tree_map_fastpath`` is stubbed to
        return None, presumably so the traced call does not re-enter the fast
        path.
        """
        tree_map_fn_copy = tree_map_fn.clone()
        tree_map_fn_copy._maybe_call_tree_map_fastpath = lambda *args, **kwargs: None  # type: ignore[missing-attribute]
        log.debug(
            "tree_map fastpath fallback triggered for %s (rest=%s, kwargs=%s)",
            self,
            rest,
            tree_map_kwargs,
        )
        return tree_map_fn_copy.call_function(
            tx,
            [map_fn, self, *rest],
            tree_map_kwargs,
        )

    def set_name_hint(self, name: str) -> None:
        pass

    def realize(self) -> "VariableTracker":
        """Used by LazyVariableTracker to build the real VariableTracker"""
        return self

    def unwrap(self) -> "VariableTracker":
        """Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
        return self

    def is_realized(self) -> bool:
        """Used by LazyVariableTracker to indicate an unrealized node"""
        return True

    def next_variable(self, tx: Any) -> "VariableTracker":
        """Trace ``next(self)``; the base implementation graph-breaks."""
        unimplemented(
            gb_type="Unsupported next() call",
            context=f"next({self})",
            explanation=f"Dynamo does not know how to trace calling `next()` on variable `{self}`.",
            hints=[*graph_break_hints.USER_ERROR],
        )

    def is_strict_mode(self, tx: Any) -> bool:
        """True if ``tx.strict_checks_fn`` is set and returns truthy for self."""
        return bool(tx.strict_checks_fn and tx.strict_checks_fn(self))

    def is_mutable(self) -> bool:
        """Whether Dynamo allows mutation on this variable."""
        return not self.is_immutable()

    def is_immutable(self) -> bool:
        """Whether Dynamo bans mutation on this variable (no mutation_type)."""
        return self.mutation_type is None

    @staticmethod
    def build(
        tx: Any,
        value: Any,
        source: Optional[Source] = None,
        realize: bool = False,
    ) -> Any:
        """Create a new VariableTracker from a value and optional Source.

        - No source: ``builder.SourcelessBuilder.create``.
        - ``realize=True``: ``builder.VariableBuilder`` is applied immediately.
        - Value whose exact type is in ``LazyConstantVariable.supported_types``:
          a ``LazyConstantVariable``.
        - Otherwise: a ``LazyVariableTracker``.
        """
        if source is None:
            return builder.SourcelessBuilder.create(tx, value)
        elif realize:
            return builder.VariableBuilder(tx, source)(value)
        elif type(value) in variables.LazyConstantVariable.supported_types:
            # Use LazyConstantVariable for primitives to enable deferred
            # guard installation - constants that are just passed through
            # won't cause recompilation when their values change.
            return variables.LazyConstantVariable.create(value, source)
        else:
            return variables.LazyVariableTracker.create(value, source)

    def is_python_hashable(self) -> bool:
        """
        Unlike the variable tracker's own __hash__, this method checks whether
        the underlying Python object referenced by this variable tracker is hashable.

        The base implementation graph-breaks.
        """
        type_self = self._python_type_for_message()
        unimplemented(
            gb_type="Dynamo cannot determine whether the underlying object is hashable",
            context=f"is_python_hashable {self}",
            explanation=f"Dynamo does not know whether the underlying python object for {self} is hashable",
            hints=[
                (
                    f"Consider using a different type of object as the dictionary key instead of {type_self}."
                ),
                *graph_break_hints.SUPPORTABLE,
            ],
        )

    def get_python_hash(self) -> int:
        """
        Unlike the variable tracker's own __hash__, this method is used by
        ConstDictVariableTracker to compute the hash of the underlying key object.

        The base implementation graph-breaks. The hint no longer calls
        python_type() directly, since that can raise NotImplementedError and
        replace the graph break with an internal error.
        """
        unimplemented(
            gb_type="Dynamo cannot determine the hash of an object",
            context=f"get_python_hash {self}",
            explanation=f"Dynamo does not know the hash of the underlying python object for {self}",
            hints=[
                (
                    f"Consider using a different type of object as the dictionary key instead of {self._python_type_for_message()}."
                ),
                *graph_break_hints.SUPPORTABLE,
            ],
        )

    def is_python_equal(self, other: object) -> bool:
        """
        NB - Deliberately not overriding the __eq__ method because that can
        disable the __hash__ for the vt itself.

        The base implementation graph-breaks (its hint uses the non-raising
        type lookup, like get_python_hash).
        """
        unimplemented(
            gb_type="Dynamo cannot determine the equality comparison of an object",
            context=f"is_python_equal {self}",
            explanation=f"Dynamo does not know the equality comparison of the underlying python object for {self}",
            hints=[
                (
                    f"Consider using a different type of object as the dictionary key instead of {self._python_type_for_message()}."
                ),
                *graph_break_hints.SUPPORTABLE,
            ],
        )

    def __init__(
        self,
        *,
        source: Optional[Source] = None,
        mutation_type: Optional[MutationType] = None,
    ) -> None:
        """
        Args:
            source: origin of the tracked value, if it existed before tracing.
            mutation_type: the mutation Dynamo allows on the value; None makes
                the variable immutable (see is_immutable()).
        """
        super().__init__()
        self.source = source
        self.mutation_type = mutation_type

        # NOTE sometimes mutation_type is set afterwards for implementation
        # convenience, we don't validate those cases at the moment.
        if mutation_type is not None:
            if isinstance(mutation_type, (ValueMutationNew, AttributeMutationNew)):
                # If this fails, it's either
                # 1. one mistakenly passed in a source
                # 2. `mutation_type` is incorrect
                assert source is None
            else:
                assert isinstance(
                    mutation_type, (ValueMutationExisting, AttributeMutationExisting)
                )

                # If this fails, it's either
                # 1. one forgot to pass in a source
                # 2. `mutation_type` is incorrect
                assert source is not None

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        Wraps all subclasses' `as_python_constant` and `reconstruct` so that it cannot be
        called twice in the same call chain - i.e. self-referential objects.

        For `as_python_constant` - self-referential objects are NOT treated as constants.
        For `reconstruct` - we will graph break. The graph break can be avoided if the VT subclass
        can generate and cache itself before recursively `reconstruct`ing - see ListVariable for an example.
        """
        super().__init_subclass__(**kwargs)

        def as_python_constant_failure(self) -> NoReturn:
            raise AsPythonConstantNotImplementedError(
                self, msg=f"{self} is self-referential"
            )

        VariableTracker._add_call_once_guard(
            cls, "as_python_constant", as_python_constant_failure
        )

        # The emitted hint text below (including the existing misspelling
        # "reconstrtuction") is kept verbatim so the graph-break message is
        # unchanged.
        def reconstruct_failure(self) -> NoReturn:
            unimplemented(
                gb_type="Reconstruction failure (self-referential)",
                context=str(self),
                explanation=f"Dynamo tried to reconstruct sourceless variable {self}, but it is self-referential. "
                "Dynamo must manually implement reconstruction rules for self-referentiable sourceless variables.",
                hints=[
                    "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable "
                    "that Dynamo cannot reconstruct, then remove it from the return statement.",
                    "Remove the self-reference in the variable. A self-referring list, for example, is `l = []; l.append(l)`.",
                    *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
                    "Report an issue to PyTorch if you need self-referential reconstrtuction support.",
                ],
            )

        VariableTracker._add_call_once_guard(cls, "reconstruct", reconstruct_failure)

    @staticmethod
    def _add_call_once_guard(
        cls: type["VariableTracker"],
        method: str,
        callback: Callable[["VariableTracker"], Any],
    ) -> None:
        """Wrap ``cls.<method>`` to detect re-entrant calls on the same instance.

        A re-entrant call (same ``id(self)`` and method name, within the current
        contextvars context) invokes ``callback(self)`` before proceeding; the
        callbacks used in __init_subclass__ always raise. Classes that inherit
        the base implementation unchanged, or whose method is already wrapped,
        are left untouched.
        """
        original_method = getattr(cls, method)
        if original_method is getattr(VariableTracker, method) or hasattr(
            original_method, "_call_once_guarded"
        ):
            return

        @functools.wraps(original_method)
        def guarded_method(self, *args: Any, **kwargs: Any) -> Any:
            active = _vt_active_calls.get()
            if active is None:
                active = set()
                _vt_active_calls.set(active)

            key = (id(self), method)
            if key in active:
                callback(self)

            active.add(key)
            try:
                return original_method(self, *args, **kwargs)
            finally:
                # Always unmark, even if the wrapped method raised.
                active.discard(key)

        guarded_method._call_once_guarded = True  # pyrefly: ignore[missing-attribute]
        setattr(cls, method, guarded_method)
def raise_type_error_exc(tx: Any, msg_str: str) -> NoReturn:
    """Raise an observed ``TypeError`` in ``tx``.

    The exception gets a single argument: ``msg_str`` wrapped in a
    ConstantVariable.
    """
    raise_observed_exception(
        TypeError, tx, args=[variables.ConstantVariable.create(msg_str)]
    )
def typestr(*objs: object) -> str:
    """Describe objects for messages.

    A VariableTracker is rendered with ``str()``, anything else by its type
    name; with zero or several objects each is described on its own and the
    results are joined with single spaces.
    """
    if len(objs) != 1:
        return " ".join(typestr(item) for item in objs)
    (only,) = objs
    return str(only) if isinstance(only, VariableTracker) else type(only).__name__
# The plain ``type.__instancecheck__`` descriptor. The metaclasses above call it
# to get standard isinstance semantics without going through the
# LazyVariableTracker-aware override on ImplicitRealizingVariableTrackerMeta.
instancecheck = type.__dict__["__instancecheck__"]
- from . import builder
- from .lazy import LazyVariableTracker
|