| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- from __future__ import annotations
- import collections
- import functools
- import inspect
- from typing import Any, TYPE_CHECKING
- from ..utils import is_function_or_wrapper
- from .base import VariableTracker, VariableTrackerMeta
- if TYPE_CHECKING:
- from collections.abc import Callable
- from typing_extensions import Self
- from .tensor import SymNodeVariable
class LazyCache:
    """Holds the inputs for building a VariableTracker, then the built result.

    Before ``realize`` runs, ``value``/``source``/``name_hint`` describe what
    to build and ``vt`` is None. ``realize`` stores the built tracker in
    ``vt`` and deletes the three input attributes.
    """

    def __init__(self, value: Any, source: Any) -> None:
        # Every value except a LazySymNodeFormatString must come with a source.
        needs_source = not isinstance(value, LazySymNodeFormatString)
        if needs_source:
            assert source
        self.value = value
        self.source = source
        self.name_hint: str | None = None
        self.vt: VariableTracker | None = None

    def realize(self) -> None:
        """Build the real VariableTracker into ``self.vt`` (must not exist yet)."""
        assert self.vt is None
        from ..symbolic_convert import InstructionTranslator
        from . import builder

        tx = InstructionTranslator.current_tx()
        pending = self.value
        if isinstance(pending, LazySymNodeFormatString):
            built = builder.SourcelessBuilder.create(tx, pending)
        else:
            # allow_lazy_constant=False keeps VariableBuilder from returning a
            # LazyConstantVariable; realizing that would land back here and
            # recurse forever.
            make_vt = builder.VariableBuilder(
                tx, self.source, allow_lazy_constant=False
            )
            built = make_vt(pending)
        self.vt = built

        hint = self.name_hint
        if hint is not None:
            built.set_name_hint(hint)

        # Only the realized tracker is kept from here on.
        del self.value
        del self.source
        del self.name_hint
class LazyVariableTracker(VariableTracker, metaclass=VariableTrackerMeta):
    """
    A structure that defers the creation of the actual VariableTracker
    for a given underlying value until it is accessed.

    `realize` builds the real object through the shared LazyCache: it uses
    VariableBuilder, or SourcelessBuilder for LazySymNodeFormatString values.
    Once a LazyVariableTracker has been realized, the LazyCache keeps the
    result, so realization happens at most once.

    This object should be utilized for processing containers, or objects that
    reference other objects where we may not want to take on creating all the
    VariableTrackers right away.

    Most VariableTracker methods are not defined here. They are installed
    after the class body as "realize, then forward" wrappers.
    """

    # Flag to prevent implicit realization in isinstance checks (inherited by subclasses)
    _no_implicit_realize = True
    # _cache is bookkeeping, not a child VariableTracker; realize_all skips it.
    _nonvar_fields = {"_cache", *VariableTracker._nonvar_fields}

    @staticmethod
    def create(value: Any, source: Any, **options: Any) -> VariableTracker:
        """Return a lazy tracker for ``value`` originating from ``source``.

        Values whose exact type is in ``LazyConstantVariable.supported_types``
        get a LazyConstantVariable.

        Otherwise, the result is memoized per source in
        ``tx.output.variable_tracker_cache``. This applies only when a source
        is given, no extra options are passed, and there is a current
        InstructionTranslator. In that case, repeated calls for the same
        source return the same tracker object.
        """
        if type(value) in LazyConstantVariable.supported_types:
            return LazyConstantVariable.create(value, source, **options)

        # Cache based on source when no extra options are passed
        if source is not None and not options:
            from ..symbolic_convert import InstructionTranslator

            tx = InstructionTranslator.current_tx()
            if tx is not None:
                cache = tx.output.variable_tracker_cache
                cached = cache.get(source)
                if cached is not None:
                    return cached
                vt = LazyVariableTracker(LazyCache(value, source), source=source)
                cache[source] = vt
                return vt

        return LazyVariableTracker(LazyCache(value, source), source=source, **options)

    def __init__(self, _cache: LazyCache, **kwargs: Any) -> None:
        assert isinstance(_cache, LazyCache)
        super().__init__(**kwargs)
        self._cache = _cache

    def realize(self) -> VariableTracker:
        """Force construction of the real VariableTracker"""
        if self._cache.vt is None:
            self._cache.realize()
            assert self._cache.vt is not None
        return self._cache.vt

    def lazy_isinstance(self, cls: type) -> bool:
        """Check isinstance after realizing, used by ImplicitRealizingVariableTrackerMeta"""
        # type.__instancecheck__ is used directly so the check runs on the
        # realized tracker, bypassing any metaclass __instancecheck__ override.
        return type.__instancecheck__(cls, self.realize())

    def unwrap(self) -> VariableTracker | Self:
        """Return the real VariableTracker if it already exists"""
        if self.is_realized():
            assert self._cache.vt is not None
            return self._cache.vt
        return self

    def is_realized(self) -> bool:
        """True once the underlying tracker has been built."""
        return self._cache.vt is not None

    def clone(self, **kwargs: Any) -> VariableTracker:
        """Clone the tracker; the shared LazyCache may not be replaced.

        Passing a different ``source`` forces realization first, so the clone
        is made from the real tracker rather than from this lazy wrapper.
        """
        assert kwargs.get("_cache", self._cache) is self._cache
        if kwargs.get("source", self.source) is not self.source:
            self.realize()
        return VariableTracker.clone(self.unwrap(), **kwargs)

    def peek_type(self) -> type[Any]:
        """Type of the wrapped Python value; only valid before realization."""
        assert not self.is_realized()
        return type(self._cache.value)

    def peek_value(self) -> Any:
        """The wrapped Python value; only valid before realization."""
        assert not self.is_realized()
        return self._cache.value

    def set_name_hint(self, name: str) -> None:
        """Set the hint now if realized, otherwise stash it for realization."""
        if self.is_realized():
            self._cache.vt.set_name_hint(name)  # type: ignore[union-attr]
        else:
            self._cache.name_hint = name

    def __str__(self) -> str:
        variable_info = "LazyVariableTracker("
        if self.is_realized():
            variable_info += f"realized: {repr(self.unwrap())})"
        else:
            variable_info += f"unrealized: {self.peek_type()})"

        return variable_info

    def __getattr__(self, item: str) -> Any:
        """Realize and forward any attribute not found by normal lookup."""
        # __getattr__ only runs when normal lookup fails. If `_cache` itself is
        # missing (an instance built via __new__, e.g. by copy/pickle, or an
        # attribute read during super().__init__ before `_cache` is assigned),
        # forwarding would call realize(), which reads `_cache`, which lands
        # back here: unbounded recursion. Report a plain AttributeError instead.
        if item == "_cache":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return getattr(self.realize(), item)

    # most methods are auto-generated below, these are the ones we want to exclude
    visit = VariableTracker.visit  # type: ignore[assignment]
    __repr__ = __str__

    @classmethod
    def realize_all(
        cls,
        value: Any,
        cache: dict[int, tuple[Any, Any]] | None = None,
        *,
        allow_lazy_constant: bool = False,
    ) -> Any:
        """
        Walk an object and realize all LazyVariableTrackers inside it.

        Plain VariableTrackers are updated in place. Their non-``_nonvar_fields``
        attributes are replaced with realized versions. Lists, tuples, dicts
        and OrderedDicts are rebuilt; note that an OrderedDict comes back as a
        plain dict. ``cache`` maps ``id(obj)`` to ``(result, obj)`` so shared
        or cyclic references are handled once. With ``allow_lazy_constant``,
        LazyConstantVariables are left unrealized.
        """
        if cache is None:
            cache = {}

        idx = id(value)
        if idx in cache:
            return cache[idx][0]

        value_cls = type(value)
        if issubclass(value_cls, LazyVariableTracker):
            # Allow LazyConstantVariable to stay lazy when returning from a frame
            keep_lazy = allow_lazy_constant and isinstance(value, LazyConstantVariable)
            if keep_lazy:
                result = value
            else:
                result = cls.realize_all(
                    value.realize(), cache, allow_lazy_constant=allow_lazy_constant
                )
        elif issubclass(value_cls, VariableTracker):
            # update value in-place
            result = value
            # update cache now to prevent infinite recursion
            cache[idx] = (result, value)
            value_dict = value.__dict__
            nonvars = value._nonvar_fields
            for key in value_dict:
                if key not in nonvars:
                    value_dict[key] = cls.realize_all(
                        value_dict[key], cache, allow_lazy_constant=allow_lazy_constant
                    )
        elif value_cls is list:
            result = [
                cls.realize_all(v, cache, allow_lazy_constant=allow_lazy_constant)
                for v in value
            ]
        elif value_cls is tuple:
            result = tuple(
                cls.realize_all(v, cache, allow_lazy_constant=allow_lazy_constant)
                for v in value
            )
        elif value_cls in (dict, collections.OrderedDict):
            result = {
                k: cls.realize_all(v, cache, allow_lazy_constant=allow_lazy_constant)
                for k, v in list(value.items())
            }
        else:
            result = value

        # save `value` to keep it alive and ensure id() isn't reused
        cache[idx] = (result, value)
        return result

    def is_hashable(self) -> bool:
        # Checks that the underlying value is hashable without realizing the VT.
        # This is used by ConstDictVariable tracker to find if the key LazyVT
        # can be hashed.
        def _helper(value: Any) -> bool:
            # TODO: Add support for more types
            return (
                inspect.isbuiltin(value)
                or issubclass(type(value), type)
                or is_function_or_wrapper(value)
            )

        assert not self.is_realized()
        value = self._cache.value
        if isinstance(value, tuple):
            return all(_helper(v) for v in value)
        return _helper(value)

    def original_value(self) -> Any:
        # Returns the value without realizing the VT.
        assert not self.is_realized()
        return self._cache.value

    def original_source(self) -> Any:
        # Returns the source without realizing the VT.
        assert not self.is_realized()
        return self._cache.source
class LazyConstantVariable(LazyVariableTracker):
    """
    Lazy tracker for int, float, bool and str constants.

    Building the constant, and therefore guarding on it, waits until the value
    is used in a way that needs it. A constant that is only passed through
    (for example returned without being used in control flow or math) then
    avoids unnecessary recompilation when its value changes.
    """

    supported_types = (int, float, bool, str)

    @staticmethod
    def create(  # pyrefly: ignore[bad-override]
        value: Any,
        source: Any,
        **options: Any,
    ) -> LazyConstantVariable:
        """Wrap ``value``, whose exact type must be one of ``supported_types``."""
        assert type(value) in LazyConstantVariable.supported_types
        pending = LazyCache(value, source)
        return LazyConstantVariable(pending, source=source, **options)
class LazySymNodeFormatString:
    """Deferred formatting of a symbolic value with a format spec.

    Stores the SymNodeVariable and a ``"{:<spec>}"`` template. The text is
    only produced when ``repr`` is taken, which evaluates the symbolic value
    at that point.
    """

    def __init__(
        self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker
    ) -> None:
        from .constant import ConstantVariable

        self.sym_node_var = sym_node_variable
        # Wrap the raw spec (e.g. ">5", ".2f") into a str.format template.
        self.fmt_var = ConstantVariable.create(
            "{:" + fmt_spec_var.as_python_constant() + "}"
        )

    def __repr__(self) -> str:
        template = self.fmt_var.as_python_constant()
        # Assumes evaluate_expr() yields a plain Python int/float/bool;
        # TODO(review): confirm against SymNodeVariable.
        value = self.sym_node_var.evaluate_expr()
        try:
            # Format the value itself, as Python does for a real number. That
            # way numeric specs ("d", ".2f") work, and width-only specs
            # right-justify the way they do for numbers.
            return str.format(template, value)
        except (TypeError, ValueError):
            # The spec does not apply to the value's own type (e.g. "s" on an
            # int). Keep the previous behaviour of formatting its string form.
            return str.format(template, str(value))
def _create_realize_and_forward(
    name: str,
) -> Callable[[LazyVariableTracker, Any, Any], Any]:
    """Make a LazyVariableTracker method that realizes, then delegates ``name``.

    The returned function takes the metadata of ``VariableTracker.<name>``
    (via ``functools.wraps``). When called, it realizes the lazy tracker and
    invokes ``name`` on the real tracker with the same arguments.
    """
    template = getattr(VariableTracker, name)

    def _forward(self: LazyVariableTracker, *args: Any, **kwargs: Any) -> Any:
        real_vt = self.realize()
        method = getattr(real_vt, name)
        return method(*args, **kwargs)

    return functools.wraps(template)(_forward)
def _populate() -> None:
    """Install realize-and-forward wrappers on LazyVariableTracker.

    Every callable defined directly in ``VariableTracker.__dict__`` that
    LazyVariableTracker does not define itself gets a wrapper.
    """
    # mappingproxy is a live view, so names added below are visible to it.
    lazy_namespace = LazyVariableTracker.__dict__
    for attr_name, attr in VariableTracker.__dict__.items():
        if attr_name in lazy_namespace or not callable(attr):
            continue
        setattr(
            LazyVariableTracker, attr_name, _create_realize_and_forward(attr_name)
        )


_populate()
|