constant.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. """
  2. Constant and enum variable tracking in Dynamo.
  3. This module is fundamental to Dynamo's ability to track and propagate constant
  4. values during compilation, ensuring proper handling of Python literals and
  5. maintaining type safety through the compilation process.
  6. """
  7. import enum
  8. import operator
  9. from collections.abc import Sequence
  10. from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union
  11. from typing_extensions import Never, override
  12. import torch
  13. from torch._dynamo.source import AttrSource, GetItemSource
  14. from .. import graph_break_hints, variables
  15. from ..exc import raise_observed_exception, unimplemented
  16. from ..utils import (
  17. cmp_name_to_op_mapping,
  18. common_constant_types,
  19. istype,
  20. np,
  21. raise_args_mismatch,
  22. raise_on_overridden_hash,
  23. )
  24. from .base import ValueMutationNew, VariableTracker
  25. if TYPE_CHECKING:
  26. from torch._dynamo.symbolic_convert import InstructionTranslator
  27. from .functions import UserFunctionVariable
  28. class ConstantVariable(VariableTracker):
  29. """
  30. Variable tracker for Python literals and basic immutable types, with automatic
  31. routing support for collection types (lists, tuples, sets, etc.).
  32. The create() method intelligently constructs appropriate variable types for
  33. nested collections.
  34. """
  35. @overload
  36. @staticmethod
  37. def create(value: None) -> Never: ...
  38. @overload
  39. @staticmethod
  40. def create(value: bool) -> "ConstantVariable": ...
  41. # TODO: Refactor to make these return ConstantVariable
  42. @overload
  43. @staticmethod
  44. def create(value: Any, **kwargs: Any) -> VariableTracker: ...
  45. @staticmethod
  46. def create(value: Any, **kwargs: Any) -> VariableTracker:
  47. """
  48. Create a `ConstantVariable` based on the given value, and supports
  49. automatic routing for collection types like `tuple` (in which case we'd
  50. create `ConstantVariable` for the leaf items).
  51. NOTE: the caller must install the proper guards if needed; most often
  52. the guard will be `CONSTANT_MATCH`.
  53. """
  54. source = kwargs.get("source")
  55. # Routing for supported collection literals.
  56. if isinstance(value, set):
  57. items = [ConstantVariable.create(x) for x in value]
  58. return variables.SetVariable(items, **kwargs) # type: ignore[arg-type]
  59. elif isinstance(value, frozenset):
  60. items = [ConstantVariable.create(x) for x in value]
  61. return variables.FrozensetVariable(items, **kwargs) # type: ignore[arg-type]
  62. elif isinstance(value, slice):
  63. slice_args = (value.start, value.stop, value.step)
  64. slice_args_vars = tuple(ConstantVariable.create(arg) for arg in slice_args)
  65. return variables.SliceVariable(slice_args_vars, **kwargs)
  66. elif isinstance(value, (list, tuple)):
  67. items = []
  68. for i, x in enumerate(value):
  69. item_source = GetItemSource(source, i) if source else None
  70. items.append(
  71. ConstantVariable.create(
  72. x,
  73. source=item_source,
  74. )
  75. )
  76. return variables.BaseListVariable.cls_for(type(value))(items, **kwargs)
  77. return ConstantVariable(value, **kwargs)
  78. def __init__(self, value: Any, **kwargs: Any) -> None:
  79. super().__init__(**kwargs)
  80. assert ConstantVariable.is_base_literal(value), f"""
  81. Cannot construct `ConstantVariable` for value of type {type(value)}.
  82. This failure likely due to PyTorch-internal use of `ConstantVariable` on
  83. non-literal python values, please try using `VariableTracker.build` instead. If
  84. you believe it's a necessary and legitimate use case (the value is immutable and
  85. can't easily be represented with another `VariableTracker` class), please add
  86. its type to `common_constant_types`.
  87. """
  88. if np is not None and isinstance(value, np.number):
  89. self.value = value.item()
  90. else:
  91. self.value = value
  92. def as_proxy(self) -> Any:
  93. return self.value
  94. def __repr__(self) -> str:
  95. return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})"
  96. def as_python_constant(self) -> Any:
  97. return self.value
  98. def is_python_constant(self) -> Literal[True]:
  99. return True
  100. def is_symnode_like(self) -> bool:
  101. return isinstance(self.value, (int, bool))
  102. def is_constant_match(self, *values: Any) -> bool:
  103. return self.value in values
  104. def is_constant_none(self) -> bool:
  105. return self.value is None
  106. @property
  107. def items(self) -> list[VariableTracker]:
  108. """
  109. Need this when adding a BaseListVariable and a ConstantVariable together.
  110. Happens in detectron2.
  111. """
  112. return self.unpack_var_sequence(tx=None)
  113. def getitem_const(
  114. self, tx: "InstructionTranslator", arg: VariableTracker
  115. ) -> VariableTracker:
  116. return ConstantVariable.create(
  117. self.value[arg.as_python_constant()],
  118. )
  119. @staticmethod
  120. def is_base_literal(obj: object) -> bool:
  121. return type(obj) in common_constant_types
  122. @staticmethod
  123. def is_literal(obj: object, cache: dict[int, object] | None = None) -> bool:
  124. if cache is None:
  125. cache = {}
  126. if id(obj) in cache:
  127. # no-op if there is a cyclical reference
  128. return True
  129. if type(obj) in (list, tuple, set, frozenset, torch.Size):
  130. cache[id(obj)] = obj
  131. return all(ConstantVariable.is_literal(x, cache) for x in obj) # type: ignore[attr-defined]
  132. return ConstantVariable.is_base_literal(obj)
  133. def unpack_var_sequence(
  134. self, tx: Optional["InstructionTranslator"]
  135. ) -> list[VariableTracker]:
  136. try:
  137. return [ConstantVariable.create(x) for x in self.as_python_constant()]
  138. except TypeError as e:
  139. raise NotImplementedError from e
  140. def const_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  141. if not hasattr(self.value, name):
  142. raise_observed_exception(AttributeError, tx, args=[name])
  143. member = getattr(self.value, name)
  144. if callable(member):
  145. raise NotImplementedError
  146. return member
  147. def call_method(
  148. self,
  149. tx: "InstructionTranslator",
  150. name: str,
  151. args: list[VariableTracker],
  152. kwargs: dict[str, VariableTracker],
  153. ) -> VariableTracker:
  154. from .tensor import SymNodeVariable
  155. if name == "format" and istype(self.value, str):
  156. return variables.BuiltinVariable(str.format).call_function(
  157. tx, [self, *args], kwargs
  158. )
  159. elif name == "join" and istype(self.value, str):
  160. if kwargs or len(args) != 1:
  161. raise_args_mismatch(
  162. tx,
  163. name,
  164. "1 args and 0 kwargs",
  165. f"{len(args)} args and {len(kwargs)} kwargs",
  166. )
  167. arg_unpacked = args[0].force_unpack_var_sequence(tx)
  168. try:
  169. arg_const = [x.as_python_constant() for x in arg_unpacked]
  170. return ConstantVariable.create(self.value.join(arg_const))
  171. except NotImplementedError:
  172. return super().call_method(tx, name, args, kwargs)
  173. elif name == "__iter__" and istype(self.value, str):
  174. # this could be some generic iterator to avoid the circular import,
  175. # but ListIterator does what we want
  176. from .lists import ListIteratorVariable
  177. return ListIteratorVariable(
  178. self.unpack_var_sequence(tx), mutation_type=ValueMutationNew()
  179. )
  180. if any(isinstance(x, SymNodeVariable) for x in args):
  181. # Promote to SymNodeVariable for operations involving dynamic shapes.
  182. return variables.SymNodeVariable.create(
  183. tx, self.as_proxy(), self.value
  184. ).call_method(tx, name, args, kwargs)
  185. try:
  186. const_args = [a.as_python_constant() for a in args]
  187. const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
  188. except NotImplementedError:
  189. return super().call_method(tx, name, args, kwargs)
  190. if isinstance(self.value, str) and name in str.__dict__:
  191. method = getattr(self.value, name)
  192. try:
  193. return ConstantVariable.create(method(*const_args, **const_kwargs))
  194. except Exception as e:
  195. raise_observed_exception(type(e), tx)
  196. elif isinstance(self.value, (float, int)) and hasattr(self.value, name):
  197. if not (args or kwargs):
  198. try:
  199. return ConstantVariable.create(getattr(self.value, name)())
  200. except (OverflowError, ValueError) as exc:
  201. raise_observed_exception(
  202. type(exc),
  203. tx,
  204. args=list(map(ConstantVariable.create, exc.args)),
  205. )
  206. if (
  207. hasattr(operator, name)
  208. and len(args) == 1
  209. and args[0].is_python_constant()
  210. ):
  211. add_target = const_args[0]
  212. op = getattr(operator, name)
  213. if isinstance(
  214. add_target, (torch.SymBool, torch.SymFloat, torch.SymInt)
  215. ):
  216. # Addition between a non sym and sym makes a sym
  217. proxy = tx.output.create_proxy(
  218. "call_function", op, (self.value, add_target), {}
  219. )
  220. return SymNodeVariable.create(tx, proxy, add_target)
  221. else:
  222. try:
  223. return ConstantVariable.create(op(self.value, add_target))
  224. except Exception as e:
  225. raise_observed_exception(
  226. type(e), tx, args=list(map(ConstantVariable.create, e.args))
  227. )
  228. elif isinstance(self.value, bytes) and name == "decode":
  229. method = getattr(self.value, name)
  230. return ConstantVariable.create(method(*const_args, **const_kwargs))
  231. elif type(self.value) is complex and name in complex.__dict__:
  232. method = getattr(self.value, name)
  233. try:
  234. return ConstantVariable.create(method(*const_args, **const_kwargs))
  235. except Exception as e:
  236. raise_observed_exception(type(e), tx)
  237. if name == "__len__" and not (args or kwargs):
  238. try:
  239. # pyrefly: ignore [bad-argument-type]
  240. return ConstantVariable.create(len(self.value))
  241. except TypeError as e:
  242. raise_observed_exception(type(e), tx, args=list(e.args))
  243. elif name == "__round__" and len(args) == 1 and args[0].is_python_constant():
  244. try:
  245. return ConstantVariable.create(
  246. # pyrefly: ignore [no-matching-overload]
  247. round(self.value, args[0].as_python_constant())
  248. )
  249. except Exception as e:
  250. raise_observed_exception(
  251. type(e), tx, args=list(map(ConstantVariable.create, e.args))
  252. )
  253. elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
  254. assert not kwargs
  255. search = args[0].as_python_constant()
  256. try:
  257. # pyrefly: ignore [not-iterable, unsupported-operation]
  258. result = search in self.value
  259. return ConstantVariable.create(result)
  260. except TypeError as e:
  261. raise_observed_exception(
  262. type(e), tx, args=list(map(ConstantVariable.create, e.args))
  263. )
  264. return super().call_method(tx, name, args, kwargs)
  265. def call_tree_map(
  266. self,
  267. tx: "InstructionTranslator",
  268. tree_map_fn: "UserFunctionVariable",
  269. map_fn: VariableTracker,
  270. rest: Sequence[VariableTracker],
  271. tree_map_kwargs: dict[str, VariableTracker],
  272. ) -> VariableTracker:
  273. if self.value is None:
  274. none_is_leaf_var = tree_map_kwargs.get("none_is_leaf")
  275. if none_is_leaf_var is not None:
  276. try:
  277. none_is_leaf = bool(none_is_leaf_var.as_python_constant())
  278. except NotImplementedError:
  279. return self._tree_map_fallback(
  280. tx,
  281. tree_map_fn,
  282. map_fn,
  283. rest,
  284. tree_map_kwargs,
  285. )
  286. else:
  287. tree_map_module = getattr(
  288. getattr(tree_map_fn, "fn", None), "__module__", ""
  289. )
  290. # torch.utils._pytree and torch.utils._cxx_pytree treat None as a leaf
  291. # by default, while optree keeps it as an internal node unless
  292. # none_is_leaf=True is provided.
  293. none_is_leaf = not tree_map_module.startswith("optree")
  294. if none_is_leaf:
  295. return map_fn.call_function(tx, [self, *rest], {})
  296. else:
  297. for other in rest:
  298. if not other.is_constant_none():
  299. return self._tree_map_fallback(
  300. tx,
  301. tree_map_fn,
  302. map_fn,
  303. rest,
  304. tree_map_kwargs,
  305. )
  306. return self.clone()
  307. if isinstance(self.value, (int, float, bool, complex, str, bytes, torch.dtype)):
  308. return map_fn.call_function(tx, [self, *rest], {})
  309. return super().call_tree_map(
  310. tx,
  311. tree_map_fn,
  312. map_fn,
  313. rest,
  314. tree_map_kwargs,
  315. )
  316. @override
  317. def call_obj_hasattr(
  318. self, tx: "InstructionTranslator", name: str
  319. ) -> "ConstantVariable":
  320. result = hasattr(self.value, name)
  321. return variables.ConstantVariable.create(result)
  322. def is_python_hashable(self) -> Literal[True]:
  323. return True
  324. def get_python_hash(self) -> int:
  325. return hash(self.value)
  326. def is_python_equal(self, other: object) -> bool:
  327. # Could be an EnumVariable as well
  328. from .tensor import SymNodeVariable
  329. if isinstance(other, SymNodeVariable):
  330. return self.as_python_constant() == other.evaluate_expr()
  331. return (
  332. isinstance(other, VariableTracker)
  333. and self.as_python_constant() == other.as_python_constant()
  334. )
  335. CONSTANT_VARIABLE_NONE = ConstantVariable(None)
  336. class EnumVariable(VariableTracker):
  337. """VariableTracker for enum.Enum and enum.IntEnum instances
  338. Provides specialized handling for Python enum types, supporting
  339. both standard Enum and IntEnum with proper value tracking and comparison.
  340. """
  341. def __init__(self, value: Union[enum.Enum, enum.IntEnum], **kwargs: Any) -> None:
  342. super().__init__(**kwargs)
  343. self.value = value
  344. @classmethod
  345. def create(
  346. cls, cls_type: Any, value_vt: VariableTracker, options: Any
  347. ) -> "EnumVariable":
  348. if value_vt.is_python_constant():
  349. for member in list(cls_type):
  350. if member.value == value_vt.as_python_constant():
  351. return cls(member, **options)
  352. unimplemented(
  353. gb_type="Failed to construct Enum variable",
  354. context=f"value: {value_vt}, allowed enum values: {list(cls_type)}",
  355. explanation="Attempted to construct an Enum value that is non-constant (e.g. int, string) "
  356. "or is not an acceptable value for the Enum. "
  357. f"Acceptable values for Enum `{cls_type}`: {list(cls_type)}.",
  358. hints=[*graph_break_hints.USER_ERROR, *graph_break_hints.SUPPORTABLE],
  359. )
  360. def as_proxy(self) -> Union[enum.Enum, int]:
  361. if isinstance(self.value, int):
  362. return int(self.value) # convert IntEnum to a normal int
  363. return self.value
  364. def __repr__(self) -> str:
  365. return f"EnumVariable({type(self.value)})"
  366. def as_python_constant(self) -> Union[enum.Enum, enum.IntEnum]:
  367. return self.value
  368. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  369. if not hasattr(self.value, name):
  370. raise NotImplementedError
  371. if name in cmp_name_to_op_mapping:
  372. return variables.GetAttrVariable(self, name)
  373. member = getattr(self.value, name)
  374. source = self.source and AttrSource(self.source, name)
  375. return VariableTracker.build(tx, member, source=source)
  376. def is_python_hashable(self) -> Literal[True]:
  377. raise_on_overridden_hash(self.value, self)
  378. return True
  379. def get_python_hash(self) -> int:
  380. return hash(self.as_python_constant())
  381. def is_python_equal(self, other: object) -> bool:
  382. return (
  383. isinstance(other, VariableTracker)
  384. and self.as_python_constant() == other.as_python_constant()
  385. )