lazy.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. from __future__ import annotations
  2. import collections
  3. import functools
  4. import inspect
  5. from typing import Any, TYPE_CHECKING
  6. from ..utils import is_function_or_wrapper
  7. from .base import VariableTracker, VariableTrackerMeta
  8. if TYPE_CHECKING:
  9. from collections.abc import Callable
  10. from typing_extensions import Self
  11. from .tensor import SymNodeVariable
  12. class LazyCache:
  13. """Container to cache the real VariableTracker"""
  14. def __init__(self, value: Any, source: Any) -> None:
  15. if not isinstance(value, LazySymNodeFormatString):
  16. assert source
  17. self.value = value
  18. self.source = source
  19. self.name_hint: str | None = None
  20. self.vt: VariableTracker | None = None
  21. def realize(self) -> None:
  22. assert self.vt is None
  23. from ..symbolic_convert import InstructionTranslator
  24. from . import builder
  25. tx = InstructionTranslator.current_tx()
  26. if isinstance(self.value, LazySymNodeFormatString):
  27. self.vt = builder.SourcelessBuilder.create(tx, self.value)
  28. else:
  29. # Pass allow_lazy_constant=False to prevent VariableBuilder from
  30. # returning LazyConstantVariable, which would cause infinite recursion
  31. # when LazyVariableTracker.realize() returns LazyConstantVariable.
  32. self.vt = builder.VariableBuilder(
  33. tx, self.source, allow_lazy_constant=False
  34. )(self.value)
  35. if self.name_hint is not None:
  36. self.vt.set_name_hint(self.name_hint)
  37. del self.value
  38. del self.source
  39. del self.name_hint
  40. class LazyVariableTracker(VariableTracker, metaclass=VariableTrackerMeta):
  41. """
  42. A structure that defers the creation of the actual VariableTracker
  43. for a given underlying value until it is accessed.
  44. The `realize` function invokes VariableTracker.build() to produce the real object.
  45. Once a LazyVariableTracker has been realized, internal bookkeeping will
  46. prevent double realization.
  47. This object should be utilized for processing containers, or objects that
  48. reference other objects where we may not want to take on creating all the
  49. VariableTrackers right away.
  50. """
  51. # Flag to prevent implicit realization in isinstance checks (inherited by subclasses)
  52. _no_implicit_realize = True
  53. _nonvar_fields = {"_cache", *VariableTracker._nonvar_fields}
  54. @staticmethod
  55. def create(value: Any, source: Any, **options: Any) -> VariableTracker:
  56. if type(value) in LazyConstantVariable.supported_types:
  57. return LazyConstantVariable.create(value, source, **options)
  58. # Cache based on source when no extra options are passed
  59. if source is not None and not options:
  60. from ..symbolic_convert import InstructionTranslator
  61. tx = InstructionTranslator.current_tx()
  62. if tx is not None:
  63. cache = tx.output.variable_tracker_cache
  64. cached = cache.get(source)
  65. if cached is not None:
  66. return cached
  67. vt = LazyVariableTracker(LazyCache(value, source), source=source)
  68. cache[source] = vt
  69. return vt
  70. return LazyVariableTracker(LazyCache(value, source), source=source, **options)
  71. def __init__(self, _cache: LazyCache, **kwargs: Any) -> None:
  72. assert isinstance(_cache, LazyCache)
  73. super().__init__(**kwargs)
  74. self._cache = _cache
  75. def realize(self) -> VariableTracker:
  76. """Force construction of the real VariableTracker"""
  77. if self._cache.vt is None:
  78. self._cache.realize()
  79. assert self._cache.vt is not None
  80. return self._cache.vt
  81. def lazy_isinstance(self, cls: type) -> bool:
  82. """Check isinstance after realizing, used by ImplicitRealizingVariableTrackerMeta"""
  83. return type.__instancecheck__(cls, self.realize())
  84. def unwrap(self) -> VariableTracker | Self:
  85. """Return the real VariableTracker if it already exists"""
  86. if self.is_realized():
  87. assert self._cache.vt is not None
  88. return self._cache.vt
  89. return self
  90. def is_realized(self) -> bool:
  91. return self._cache.vt is not None
  92. def clone(self, **kwargs: Any) -> VariableTracker:
  93. assert kwargs.get("_cache", self._cache) is self._cache
  94. if kwargs.get("source", self.source) is not self.source:
  95. self.realize()
  96. return VariableTracker.clone(self.unwrap(), **kwargs)
  97. def peek_type(self) -> type[Any]:
  98. assert not self.is_realized()
  99. return type(self._cache.value)
  100. def peek_value(self) -> Any:
  101. assert not self.is_realized()
  102. return self._cache.value
  103. def set_name_hint(self, name: str) -> None:
  104. if self.is_realized():
  105. self._cache.vt.set_name_hint(name) # type: ignore[union-attr]
  106. else:
  107. self._cache.name_hint = name
  108. def __str__(self) -> str:
  109. variable_info = "LazyVariableTracker("
  110. if self.is_realized():
  111. variable_info += f"realized: {repr(self.unwrap())})"
  112. else:
  113. variable_info += f"unrealized: {self.peek_type()})"
  114. return variable_info
  115. def __getattr__(self, item: str) -> Any:
  116. return getattr(self.realize(), item)
  117. # most methods are auto-generated below, these are the ones we want to exclude
  118. visit = VariableTracker.visit # type: ignore[assignment]
  119. __repr__ = __str__
  120. @classmethod
  121. def realize_all(
  122. cls,
  123. value: Any,
  124. cache: dict[int, tuple[Any, Any]] | None = None,
  125. *,
  126. allow_lazy_constant: bool = False,
  127. ) -> Any:
  128. """
  129. Walk an object and realize all LazyVariableTrackers inside it.
  130. """
  131. if cache is None:
  132. cache = {}
  133. idx = id(value)
  134. if idx in cache:
  135. return cache[idx][0]
  136. value_cls = type(value)
  137. if issubclass(value_cls, LazyVariableTracker):
  138. # Allow LazyConstantVariable to stay lazy when returning from a frame
  139. keep_lazy = allow_lazy_constant and isinstance(value, LazyConstantVariable)
  140. if keep_lazy:
  141. result = value
  142. else:
  143. result = cls.realize_all(
  144. value.realize(), cache, allow_lazy_constant=allow_lazy_constant
  145. )
  146. elif issubclass(value_cls, VariableTracker):
  147. # update value in-place
  148. result = value
  149. # update cache now to prevent infinite recursion
  150. cache[idx] = (result, value)
  151. value_dict = value.__dict__
  152. nonvars = value._nonvar_fields
  153. for key in value_dict:
  154. if key not in nonvars:
  155. value_dict[key] = cls.realize_all(
  156. value_dict[key], cache, allow_lazy_constant=allow_lazy_constant
  157. )
  158. elif value_cls is list:
  159. result = [
  160. cls.realize_all(v, cache, allow_lazy_constant=allow_lazy_constant)
  161. for v in value
  162. ]
  163. elif value_cls is tuple:
  164. result = tuple(
  165. cls.realize_all(v, cache, allow_lazy_constant=allow_lazy_constant)
  166. for v in value
  167. )
  168. elif value_cls in (dict, collections.OrderedDict):
  169. result = {
  170. k: cls.realize_all(v, cache, allow_lazy_constant=allow_lazy_constant)
  171. for k, v in list(value.items())
  172. }
  173. else:
  174. result = value
  175. # save `value` to keep it alive and ensure id() isn't reused
  176. cache[idx] = (result, value)
  177. return result
  178. def is_hashable(self) -> bool:
  179. # Checks that the underlying value is hashable without realizing the VT.
  180. # This is used by ConstDictVariable tracker to find if the key LazyVT
  181. # can be hashed.
  182. def _helper(value: Any) -> bool:
  183. # TODO: Add support for more types
  184. return (
  185. inspect.isbuiltin(value)
  186. or issubclass(type(value), type)
  187. or is_function_or_wrapper(value)
  188. )
  189. assert not self.is_realized()
  190. value = self._cache.value
  191. if isinstance(value, tuple):
  192. return all(_helper(v) for v in value)
  193. return _helper(value)
  194. def original_value(self) -> Any:
  195. # Returns the value without realizing the VT.
  196. assert not self.is_realized()
  197. return self._cache.value
  198. def original_source(self) -> Any:
  199. # Returns the source without realizing the VT.
  200. assert not self.is_realized()
  201. return self._cache.source
  202. class LazyConstantVariable(LazyVariableTracker):
  203. """
  204. A lazy variable tracker for constants (int, float, bool, str) that defers
  205. guarding until the value is actually used in a way that requires it.
  206. This allows constants that are just passed through (e.g., returned without
  207. being used in control flow or math) to avoid unnecessary recompilation when
  208. their values change.
  209. """
  210. supported_types = (int, float, bool, str)
  211. @staticmethod
  212. def create( # pyrefly: ignore[bad-override]
  213. value: Any,
  214. source: Any,
  215. **options: Any,
  216. ) -> LazyConstantVariable:
  217. assert type(value) in LazyConstantVariable.supported_types
  218. return LazyConstantVariable(LazyCache(value, source), source=source, **options)
  219. class LazySymNodeFormatString:
  220. def __init__(
  221. self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker
  222. ) -> None:
  223. from .constant import ConstantVariable
  224. self.sym_node_var = sym_node_variable
  225. self.fmt_var = ConstantVariable.create(
  226. "{:" + fmt_spec_var.as_python_constant() + "}"
  227. )
  228. def __repr__(self) -> str:
  229. return str.format(
  230. self.fmt_var.as_python_constant(),
  231. str(self.sym_node_var.evaluate_expr()),
  232. )
  233. def _create_realize_and_forward(
  234. name: str,
  235. ) -> Callable[[LazyVariableTracker, Any, Any], Any]:
  236. @functools.wraps(getattr(VariableTracker, name))
  237. def realize_and_forward(
  238. self: LazyVariableTracker, *args: Any, **kwargs: Any
  239. ) -> Any:
  240. return getattr(self.realize(), name)(*args, **kwargs)
  241. return realize_and_forward
  242. def _populate() -> None:
  243. for name, value in VariableTracker.__dict__.items():
  244. if name not in LazyVariableTracker.__dict__:
  245. if callable(value):
  246. setattr(LazyVariableTracker, name, _create_realize_and_forward(name))
  247. _populate()