sym_node.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. """
  4. This file does three things:
  5. - Contains the definition of SymNode
- Installs all the magic methods into SymBool, SymInt, SymFloat at import time
  7. - Does not depend on sympy at import time
  8. As this file is imported from within torch/__init__.py we do not want it to depend on SymPy
  9. to avoid having to load SymPy at import time, as doing so is *very* slow.
  10. """
  11. import builtins
  12. import functools
  13. import inspect
  14. import itertools
  15. import logging
  16. import math
  17. import operator
  18. import sys
  19. from functools import lru_cache, update_wrapper
  20. from typing import Optional, TYPE_CHECKING, Union
  21. import torch
  22. import torch._logging.structured as structured
  23. # NB: The sym_* functions are used via getattr() and must be imported here.
  24. from torch import ( # noqa: F401
  25. sym_float,
  26. sym_ite,
  27. sym_max,
  28. sym_min,
  29. sym_not,
  30. SymBool,
  31. SymFloat,
  32. SymInt,
  33. )
  34. from torch._logging import dtrace_structured
  35. if TYPE_CHECKING:
  36. from torch.fx.experimental.symbolic_shapes import ShapeEnv
# Module logger; SymNode.guard_* use it to warn when a guarded value cannot be
# converted to the requested Python type.
log = logging.getLogger(__name__)
# Artifact logger for the "sym_node" logging artifact (usage not shown here).
sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")

# Sentinel value to indicate "don't compute hint" vs actual None
# When passed as hint to SymNode, it means we already know hint is unavailable
# and should not waste time calling compute_hint()
# Compared by identity (`hint is _NO_HINT`), so it must stay a unique object.
_NO_HINT: object = object()

# Type alias for concrete hint values. NB: despite what one might expect, it
# does NOT include the _NO_HINT sentinel above (that is a plain ``object``).
# This line is evaluated at runtime (it is not an annotation, so the
# `from __future__ import annotations` import does not defer it), so the
# PEP 604 `X | Y` union syntax requires Python >= 3.10.
HintType = bool | float | int | None

# Public names of this module. "sym_<math op>" names are appended at import
# time further down. NOTE(review): "method_to_operator" and "magic_methods" are
# presumably defined later in the file (not visible here) — `import *` fails if
# any listed name is missing.
__all__ = ["SymNode", "method_to_operator", "magic_methods", "DynamicInt"]
  46. from torch.types import py_sym_types as SymTypes
  47. def _to_symtype(t):
  48. if t is bool:
  49. return SymBool
  50. if t is int:
  51. return SymInt
  52. if t is float:
  53. return SymFloat
  54. return t
# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
class SymNode:
    """
    This is a type erased SymInt/SymFloat which we use to do actual operations.
    End users don't touch this. Magic methods are NOT defined on this object.

    A SymNode bundles an expression (``_expr``), the ShapeEnv that owns its
    symbols, the Python type it stands for (``pytype``), an optional concrete
    ``hint`` for the current trace, an optional ``constant`` (set by
    wrap_int / wrap_float / wrap_bool), and, under translation validation, the
    FX node that produced it.

    The operator-style methods below (``add``, ``eq``, ``floor``, ...) forward
    to underscore-prefixed ``_<name>`` methods that are not defined in this
    class body; they are metaprogrammed onto the class elsewhere in this
    module (hence the ``type: ignore[attr-defined]`` comments).
    """

    # Note [optimized_summation]: indicates that SymNode is an Add expression of the form
    # a + b + c + d... etc where all terms are unique symbols. This allows us to do some optimizations
    # for common patterns see _optimized_add.
    # The unfortunate reason we have this here is because sympy sets __slots__ = () for add expression,
    # so we cannot add the attribute directly to the sympy expression. Furthermore, we cannot use it as
    # a weak dictionary key either! So instead, we attach the attribute here to the SymNode.
    _optimized_summation: bool = False

    def __init__(
        self,
        expr,
        shape_env,
        pytype,
        hint: Optional[Union[int, float, bool]],
        constant=None,
        fx_node=None,
        optimized_summation=False,
    ):
        """
        Args:
            expr: the expression this node represents. Stored unmodified in
                ``self._expr``; the ``expr`` property applies
                ``shape_env.replace`` on every read.
            shape_env: the ShapeEnv that owns the symbols in ``expr``.
            pytype: the Python type this node stands for (see is_int /
                is_float / is_bool).
            hint: one of
                - a concrete value for this particular trace, which must have
                  type ``pytype`` or its Sym counterpart (``_to_symtype``);
                - ``None``: try to compute one via ``compute_hint`` below;
                - the module-level ``_NO_HINT`` sentinel: record "no hint"
                  without attempting the computation.
                NOTE(review): the annotation covers neither the sentinel nor
                Sym* hints, although both are accepted.
            constant: value known to be invariant across runs (literals).
            fx_node: FX node producing this value; only kept when the
                ShapeEnv has translation validation enabled.
            optimized_summation: see Note [optimized_summation].

        Raises:
            AssertionError: if ``hint``'s type is incompatible with ``pytype``,
                or (translation validation enabled) if ``hint`` differs from
                the recomputed hint.
        """
        self._expr = expr
        self.shape_env = shape_env
        self.pytype = pytype
        self._optimized_summation = optimized_summation

        # What's the difference between hint and constant?
        #
        # - A constant is known to be invariant across invocations of the model;
        #   it will always be this value. We only really know this when we
        #   encounter an honest-to-goodness literal (when wrapping it into
        #   a SymNode, we set constant.) Most of the time, constant is None
        #
        # - A hint is a *particular* value from the particular run we are
        #   tracing, but it may vary the next time around. It's useful to
        #   keep this around, as if we need a concrete value from a SymNode,
        #   we will return the hint and guard on the expression that produced
        #   it giving the same hint next time around. The hint is not
        #   guaranteed to be set either: if you have an unbacked SymNode,
        #   there won't be any hint; it was the result of some tensor-dependent
        #   computation, but we don't know what it actually is because we
        #   haven't actually run the tensor computation.
        #
        # If _hint is None, we will query maybe_evaluate_static(compute_hint=True)
        # in hopes that we've learned enough about the unbacked symints to
        # discharge the hint; otherwise, you're likely to just error out.
        # (NB: as written below, compute_hint() returns None without querying
        # at all when the expression has free unbacked symbols.)
        #
        # (A previous version of this system had some optimizations to only
        # recompute when it was possible we had learned enough about the
        # unbacked symint that a hint was now possible, but as we added more
        # potential refinements to unbacked symints this got harder to keep
        # in sync, so we've deleted it for now.)

        def compute_hint():
            # Best-effort hint from the ShapeEnv; returns None instead of raising.
            from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols

            # This occasionally gets exercised by, e.g.,
            # convert_shape_to_symint.  It's just a nicety so you don't HAVE
            # to have a correct hint on hand when making a SymNode.
            # Don't attempt to compute for unbacked, this can be quite
            # expensive.
            if has_free_unbacked_symbols(self.expr):
                return None
            hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
            if hint is not None:
                # Coerce non-Sym results to pytype; Sym* values are kept as-is.
                hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint
            return hint

        if hint is _NO_HINT:
            # Caller explicitly indicates hint is unavailable, don't compute
            hint = None
        elif hint is not None:
            if not (type(hint) is pytype or type(hint) is _to_symtype(pytype)):
                raise AssertionError(
                    "Cannot create SymNode of type "
                    f"{pytype} with incompatible hint of type {type(hint)}"
                )
            if self.shape_env and self.shape_env._translation_validation_enabled:
                # This is technically not TV, but this assert is expensive so
                # let's only do it when we're already doing expensive things
                computed_hint = compute_hint()
                if hint != computed_hint:
                    raise AssertionError(f"{hint} != {computed_hint} (for {self.expr})")
        else:
            hint = compute_hint()
        self._hint = hint
        self.constant: Optional[Union[int, float, bool]] = constant

        # Record the FX node of the current node if we are doing translation
        # validation. They will be used for building the input assertions for
        # the translation validation problem.
        # NOTE(review): when translation validation is off, `and` stores the
        # falsy `tx_validation_en` itself (False, or a falsy shape_env such as
        # None) rather than None, so `fx_node is not None` checks (e.g. in
        # __repr__) can see False.
        tx_validation_en = (
            self.shape_env and self.shape_env._translation_validation_enabled
        )
        self.fx_node = tx_validation_en and fx_node

    def with_shape_env(self, shape_env: ShapeEnv) -> SymNode:
        """Return a new SymNode with this node's expr, pytype, hint, constant
        and fx_node, bound to ``shape_env``.

        Passing the hint through __init__ again means a None hint is
        recomputed against the new env, and fx_node is re-filtered by the new
        env's translation-validation setting.

        NOTE(review): ``optimized_summation`` is not forwarded, so the copy
        starts with the class default (False). That only forgoes the
        Note [optimized_summation] fast path; confirm before propagating it,
        since the new env's replacements may differ.
        """
        return SymNode(
            self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
        )

    def _value_eq(self, other: SymNode) -> bool:
        """Field-wise equality on (_expr, pytype, _hint, constant, fx_node)."""
        # Purposely don't include the shape_env in the eq.
        return (
            self._expr == other._expr
            and self.pytype == other.pytype
            and self._hint == other._hint
            and self.constant == other.constant
            and self.fx_node == other.fx_node
        )

    def _value_hash(self) -> int:
        """Hash consistent with _value_eq (same fields, shape_env excluded)."""
        # Purposely don't include the shape_env in the hash.
        return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))

    @property
    def expr(self):
        # Replacements known to the ShapeEnv are applied on every access.
        return self.shape_env.replace(self._expr)

    @property
    def hint(self):
        # Stored hint; may be None (e.g. unbacked).
        return self._hint

    def has_hint(self):
        return self._hint is not None

    def require_hint(self, fallback=None):
        """Return a concrete hint for this node.

        If a hint is stored it is returned. Otherwise, with ``fallback`` set,
        backed symbols are substituted by ``shape_env.backed_var_to_val`` and
        unbacked symbols by ``fallback``, and ``int()`` of the result is
        returned (NOTE(review): int() even for non-int pytypes). Without a
        fallback this defers to ``shape_env.size_hint``, which is expected to
        raise.
        """
        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

        if self._hint is None:
            if fallback is not None:
                # Say we have some expr like 2*u0 + s0
                # The hint will be None, since the expr contains at least 1 unbacked.
                # We will:
                # - replace every backed free symbol with its corresponding hint
                # - replace every unbacked free symbol with the fallback
                # - regenerate the expression with those symbol replacements
                # Note: this is not really complete either, since right now
                # this logic does not take into account any value ranges
                # for the unbacked symints, we may need to beef it up at some point.
                unbacked_symbols = free_unbacked_symbols(self.expr)
                replacements = {
                    s: fallback
                    if s in unbacked_symbols
                    else self.shape_env.backed_var_to_val[s]
                    for s in self.expr.free_symbols
                }
                return int(self.expr.xreplace(replacements))
            # NB: we expect this to raise
            return self.shape_env.size_hint(self.expr)
        return self._hint

    def maybe_as_int(self):
        # int() of a numeric expression, else None. NB: int() truncates a
        # non-integral number (e.g. a sympy Float).
        if self.expr.is_number:
            return int(self.expr)
        else:
            return None

    # NB: This does conversions, not sure if this is good or not
    def maybe_as_float(self):
        # Only a literal sympy.Float is converted; anything else gives None.
        import sympy

        if isinstance(self.expr, sympy.Float):
            return float(self.expr)
        else:
            return None

    def maybe_as_bool(self):
        # True/False only for the sympy boolean singletons; otherwise None.
        import sympy

        if self.expr is sympy.true:
            return True
        elif self.expr is sympy.false:
            return False
        else:
            return None

    def is_int(self):
        return self.pytype is int

    def is_float(self):
        return self.pytype is float

    def is_bool(self):
        return self.pytype is bool

    def is_nested_int(self):
        # A nested int is recognized by its hint being a SymInt whose node
        # reports is_nested_int().
        # Unbacked SymInts cannot be nested int today
        return (
            self._hint is not None
            and isinstance(self._hint, SymInt)
            and self._hint.node.is_nested_int()
        )

    # wrap_* build a constant SymNode in the same ShapeEnv from a Python
    # literal, which serves as expr source, hint, constant and fx_node.
    # The exact-type check means e.g. wrap_int(True) is rejected.
    def wrap_int(self, num):
        if type(num) is not int:
            raise AssertionError(f"Expected int, got {type(num)}")
        import sympy

        return SymNode(
            sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num
        )

    def wrap_float(self, num):
        if type(num) is not float:
            raise AssertionError(f"Expected float, got {type(num)}")
        import sympy

        return SymNode(
            sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num
        )

    def wrap_bool(self, num):
        if type(num) is not bool:
            raise AssertionError(f"Expected bool, got {type(num)}")
        import sympy

        return SymNode(
            sympy.true if num else sympy.false,
            self.shape_env,
            bool,
            num,
            constant=num,
            fx_node=num,
        )

    def clone(self):
        # No copy is made; the same node is returned.
        return self

    def str(self):
        return f"{self.expr}"

    def __str__(self):
        return self.str()

    def __repr__(self):
        rep = [
            f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}",
        ]
        if self._hint is not None:
            rep.append(f"hint={self._hint}")
        if self.constant is not None:
            rep.append(f"constant={self.constant}")
        # NB: fx_node may be False here (see __init__), which still prints.
        if self.fx_node is not None:
            rep.append(f"fx_node={self.fx_node}")
        return ", ".join(rep) + ")"

    # Annotated as builtins.str because the `str` method above shadows the
    # builtin name inside this class body.
    def _graph_repr(self) -> builtins.str:
        # Representation used by GraphModule to create a pythonic version of a graph
        return self.str()

    # These methods call the metaprogrammed methods, they're hand written
    # here so we get good stack traces
    def abs(self) -> SymNode:
        return self._abs()  # type: ignore[attr-defined]

    def pos(self) -> SymNode:
        return self._pos()  # type: ignore[attr-defined]

    def round(self, ndigits=None) -> SymNode:
        return self._round(ndigits)  # type: ignore[attr-defined]

    def trunc(self) -> SymNode:
        return self._trunc()  # type: ignore[attr-defined]

    def add(self, other) -> SymNode:
        return self._add(other)  # type: ignore[attr-defined]

    def sub(self, other) -> SymNode:
        return self._sub(other)  # type: ignore[attr-defined]

    def mul(self, other) -> SymNode:
        return self._mul(other)  # type: ignore[attr-defined]

    def mod(self, other) -> SymNode:
        return self._mod(other)  # type: ignore[attr-defined]

    def float_pow(self, other) -> SymNode:
        return self._float_pow(other)  # type: ignore[attr-defined]

    def pow_by_natural(self, other) -> SymNode:
        return self._pow_by_natural(other)  # type: ignore[attr-defined]

    def and_(self, other) -> SymNode:
        return self._and_(other)  # type: ignore[attr-defined]

    def or_(self, other) -> SymNode:
        return self._or_(other)  # type: ignore[attr-defined]

    def float_truediv(self, other) -> SymNode:
        return self._float_truediv(other)  # type: ignore[attr-defined]

    def int_truediv(self, other) -> SymNode:
        return self._int_truediv(other)  # type: ignore[attr-defined]

    def int_floordiv(self, other) -> SymNode:
        return self._int_floordiv(other)  # type: ignore[attr-defined]

    def lshift(self, other) -> SymNode:
        return self._lshift(other)  # type: ignore[attr-defined]

    def rshift(self, other) -> SymNode:
        return self._rshift(other)  # type: ignore[attr-defined]

    def sym_not(self) -> SymNode:  # noqa: F811
        return self._sym_not()  # type: ignore[attr-defined]

    def eq(self, other) -> SymNode:
        return self._eq(other)  # type: ignore[attr-defined]

    def ne(self, other) -> SymNode:
        return self._ne(other)  # type: ignore[attr-defined]

    def gt(self, other) -> SymNode:
        return self._gt(other)  # type: ignore[attr-defined]

    def lt(self, other) -> SymNode:
        return self._lt(other)  # type: ignore[attr-defined]

    def le(self, other) -> SymNode:
        return self._le(other)  # type: ignore[attr-defined]

    def ge(self, other) -> SymNode:
        return self._ge(other)  # type: ignore[attr-defined]

    def floor(self) -> SymNode:
        return self._floor()  # type: ignore[attr-defined]

    def is_integer(self) -> SymNode:
        return self._is_integer()  # type: ignore[attr-defined]

    def sym_float(self) -> SymNode:  # noqa: F811
        return self._sym_float()  # type: ignore[attr-defined]

    def sym_int(self) -> SymNode:
        return self._sym_int()  # type: ignore[attr-defined]

    def ceil(self) -> SymNode:
        return self._ceil()  # type: ignore[attr-defined]

    def neg(self) -> SymNode:
        return self._neg()  # type: ignore[attr-defined]

    def sym_min(self, other) -> SymNode:  # noqa: F811
        return self._sym_min(other)  # type: ignore[attr-defined]

    def sym_max(self, other) -> SymNode:  # noqa: F811
        return self._sym_max(other)  # type: ignore[attr-defined]

    def sym_ite(self, then_val, else_val) -> SymNode:
        return self._sym_ite(then_val, else_val)  # type: ignore[attr-defined]

    def is_contiguous(self, sizes, strides) -> SymNode:
        return self._is_contiguous(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode:
        return self._is_channels_last_contiguous_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode:
        return self._is_channels_last_contiguous_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_2d(self, sizes, strides) -> SymNode:
        return self._is_channels_last_strides_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_3d(self, sizes, strides) -> SymNode:
        return self._is_channels_last_strides_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode:
        return self._is_non_overlapping_and_dense_indicator(sizes, strides)  # type: ignore[attr-defined]

    # Make C++ happy
    def sym_or(self, other):
        return self.or_(other)

    def sym_and(self, other):
        return self.and_(other)

    # Integer bitwise ops
    def bitwise_and(self, other):
        return self._bitwise_and(other)  # type: ignore[attr-defined]

    def bitwise_or(self, other):
        return self._bitwise_or(other)  # type: ignore[attr-defined]

    def bitwise_xor(self, other):
        return self._bitwise_xor(other)  # type: ignore[attr-defined]

    # There is no int_truediv available from C++
    def truediv(self, other):
        return self.float_truediv(other)

    def floordiv(self, other) -> SymNode:
        return self.int_floordiv(other)

    # We didn't bind integer pow in C++
    def pow(self, other):
        return self.float_pow(other)

    def is_non_overlapping_and_dense(self, sizes, strides):
        # Compares the indicator node against a wrapped constant 1.
        return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(
            to_node(self, 1)
        )  # type: ignore[attr-defined]

    def int_(self):
        return self.guard_int("", 0)  # NB: uses Python backtrace

    # This one is currently done by hand, but if we add other variadic
    # functions consider factoring it out to be metaprogrammed too.  Note that
    # some load bearing logic is directly in torch.sym_sum

    def sym_sum(self, args) -> SymNode:
        """Sum the SymNodes in ``args`` into one int SymNode.

        Under an active proxy mode the call is dispatched as ``torch.sym_sum``
        through ``handle_sym_dispatch`` and the result is converted back with
        ``to_node``. Otherwise the result expr is ``sympy.Add`` of the operand
        exprs, and its hint is the sum of the operand hints only when every
        operand has one (else None).
        """
        import sympy

        # Inner impl
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        if get_proxy_mode():
            return to_node(
                self,
                handle_sym_dispatch(
                    torch.sym_sum,
                    (tuple(wrap_node(a) for a in args),),
                    {},
                ),
            )
        exprs = [a.expr for a in args]
        out = sympy.Add(*exprs)

        size_hints = []
        out_hint = None
        for a in args:
            if a.hint is None:
                break
            size_hints.append(a.hint)
        else:
            # Only reached when no operand was missing a hint.
            out_hint = sum(size_hints)

        fx_node, _ = self.shape_env._create_fx_call_function(
            torch.sym_sum, (tuple(a.fx_node for a in args),)
        )

        # NB: Only for integers!
        return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node)

    def evaluate(self, size_oblivious=False):
        # Delegates to the ShapeEnv, which decides the concrete value.
        return self.shape_env.evaluate_sym_node(self, size_oblivious)

    # You can manually trigger a guard with this function
    def guard_int(self, file, line):
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.evaluate()
        try:
            return int(r)
        except Exception:
            log.warning("Failed to convert to int: %s", r)
            raise

    def guard_float(self, file, line):
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.evaluate()
        try:
            return float(r)
        except Exception:
            log.warning("Failed to convert to float: %s", r)
            raise

    def guard_bool(self, file, line):
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.evaluate()
        try:
            return bool(r)
        except Exception:
            log.warning("Failed to convert to bool: %s", r)
            raise

    def expect_true(self, file, line):
        """Require this bool node to be true.

        Guards directly via guard_bool when there is a hint, no free unbacked
        symbols, and the ShapeEnv does not prefer deferred runtime asserts;
        otherwise hands off to ``shape_env.guard_or_defer_runtime_assert``
        with a ``"file:line"`` message.
        """
        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

        if (
            self.has_hint()
            and not free_unbacked_symbols(self.expr)
            and not self.shape_env.prefer_deferred_runtime_asserts_over_guards
        ):
            # OK to generate guards
            return self.guard_bool(file, line)
        # Generate a deferred runtime assert (this might actually end up doing
        # a regular guard if we can!)
        # TODO: file/line here is very important, because the assert has been
        # deferred so you can't backtrace easily
        return self.shape_env.guard_or_defer_runtime_assert(
            self.expr, f"{file}:{line}", fx_node=self.fx_node
        )

    def statically_known_true(self, file, line):
        # Delegates to symbolic_shapes.statically_known_true on a SymBool
        # wrapping this node; file/line are unused.
        from torch.fx.experimental.symbolic_shapes import statically_known_true

        if not self.is_bool():
            raise AssertionError("Expected bool type")
        return statically_known_true(SymBool(self))

    def guard_size_oblivious(self, file, line):
        """
        Like guard_bool, but if we encounter unbacked symbols, if those symbols
        are size-like, we will treat them as >= 2 for the purposes of the analysis.

        This CHANGES the runtime semantics, but all size-oblivious sites have been
        audited to ensure that the runtime semantics don't change in a material way.
        Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping
        an unbacked one size, or a tensor reporting as non-contiguous even if it's
        contiguous if it would have been reported contiguous due to being empty.
        """
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.evaluate(size_oblivious=True)
        try:
            return bool(r)
        except Exception:
            log.warning("Failed to convert to bool: %s", r)
            raise

    def guard_or_false(self, file, line):
        # Delegates to symbolic_shapes.guard_or_false; file/line are unused.
        from torch.fx.experimental.symbolic_shapes import guard_or_false

        if not self.is_bool():
            raise AssertionError("Expected bool type")
        return guard_or_false(SymBool(self))

    def guard_or_true(self, file, line):
        # Delegates to symbolic_shapes.guard_or_true; file/line are unused.
        from torch.fx.experimental.symbolic_shapes import guard_or_true

        if not self.is_bool():
            raise AssertionError("Expected bool type")
        return guard_or_true(SymBool(self))

    def bool_(self):
        return self.guard_bool("", 0)

    # Fixed answers for this node kind (presumably part of a node interface
    # shared with constant / nested-int nodes defined elsewhere).
    def is_symbolic(self):
        return True

    def nested_int(self):
        return None

    def is_constant(self):
        return False
  503. class _DynamicScalar:
  504. def __new__(cls, *args):
  505. if cls is _DynamicScalar:
  506. raise TypeError("_DynamicScalar is an abstract base class, use DynamicInt.")
  507. return super().__new__(cls, *args)
  508. class DynamicInt(_DynamicScalar, int):
  509. """
  510. User API for marking dynamic integers in `torch.compile`.
  511. Intended to be compatible with both compile and eager mode.
  512. Example usage::
  513. fn = torch.compile(f)
  514. x = DynamicInt(4)
  515. fn(x) # compiles x as a dynamic integer input; returns f(4)
  516. """
  517. def __new__(cls, val):
  518. if not isinstance(val, int):
  519. raise AssertionError(f"Expected int, got {type(val)}")
  520. obj = super().__new__(cls, int(val))
  521. return obj
  522. def __repr__(self):
  523. return f"DynamicInt({self.real})"
  524. def __floordiv__(self, other): # // was casting to int without these overrides?
  525. return DynamicInt(self.real // other)
  526. def __rfloordiv__(self, other):
  527. return DynamicInt(other // self.real)
# TODO: this probably needs the sizes-strides eval functions
# Maps a SymNode method name to the plain-Python callable for the same
# operation. Several names intentionally share a callable:
#   "and"/"bitwise_and" -> operator.and_, "or"/"bitwise_or" -> operator.or_,
#   "float_pow"/"pow_by_natural" -> operator.pow,
#   "float_truediv"/"int_truediv" -> operator.truediv.
# "sym_<math op>" entries are added to this dict further down in this module.
# NOTE(review): "sym_int" is a unary magic method but has no entry here;
# presumably it is handled separately elsewhere — confirm.
METHOD_TO_OPERATOR = {
    "pos": operator.pos,
    "abs": operator.abs,
    "add": operator.add,
    "and": operator.and_,
    "bitwise_and": operator.and_,
    "ceil": math.ceil,
    "eq": operator.eq,
    "floor": math.floor,
    "trunc": math.trunc,
    "int_floordiv": operator.floordiv,
    "ge": operator.ge,
    "gt": operator.gt,
    "is_integer": lambda x: x.is_integer(),
    "le": operator.le,
    "lshift": operator.lshift,
    "lt": operator.lt,
    "mod": operator.mod,
    "mul": operator.mul,
    "ne": operator.ne,
    "neg": operator.neg,
    "or": operator.or_,
    "bitwise_or": operator.or_,
    "bitwise_xor": operator.xor,
    "float_pow": operator.pow,
    "pow_by_natural": operator.pow,
    "round": builtins.round,
    "rshift": operator.rshift,
    "sub": operator.sub,
    "sym_float": sym_float,
    "sym_ite": sym_ite,
    "sym_max": sym_max,
    "sym_min": sym_min,
    "sym_not": sym_not,
    "float_truediv": operator.truediv,
    "int_truediv": operator.truediv,
}

# Magic methods that take no operand besides self. "sym_<math op>" names are
# added to this set further down in this module.
unary_magic_methods = {
    "abs",
    "sym_float",
    "sym_int",
    "ceil",
    "floor",
    "neg",
    "sym_not",
    "pos",
    "trunc",
}
  577. # Adding math ops: sqrt, cos, sin, ...
  578. def _get_sym_node_fn(name):
  579. def fn(self):
  580. return getattr(self, f"_sym_{name}")()
  581. return fn
# Math functions exposed as SymNode.sym_<name> methods.
math_op_names = (
    "sqrt",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "asin",
    "acos",
    "atan",
    "log2",
)
# Import-time registration. For each math op this:
#   - installs SymNode.sym_<op>, forwarding to SymNode._sym_<op> (which is not
#     defined in the class body; presumably installed later in this module),
#   - maps "sym_<op>" to torch._sym_<op> in METHOD_TO_OPERATOR,
#   - marks "sym_<op>" as a unary magic method, and
#   - appends "sym_<op>" to __all__.
# NB: the loop variables (name, sym_name, priv_sym_name) remain bound at
# module scope afterwards.
# NOTE(review): no module-level `sym_<op>` binding is visible in this section;
# confirm one exists, since `from ... import *` fails on missing __all__ names.
for name in math_op_names:
    sym_name = f"sym_{name}"
    priv_sym_name = f"_{sym_name}"
    setattr(SymNode, sym_name, _get_sym_node_fn(name))
    METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
    unary_magic_methods.add(sym_name)
    __all__.append(sym_name)
  602. # Unary methods that are not magic methods
  603. unary_nonmagic_methods = {
  604. "is_integer",
  605. }
  606. unary_methods = unary_magic_methods | unary_nonmagic_methods
  607. # Most methods are only registered on SymInt and SymFloat
  608. # Some methods are only be registered on SymBool
  609. only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
  610. # Methods that implicitly convert SymBool into SymInt
  611. bool_becomes_int_magic_methods = {"add", "sub", "mul"}
  612. # Methods that are also on SymBool, in addition to on SymInt and SymFloat
  613. also_bool_magic_methods = {"eq"}
  614. bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
  615. # Methods that are only for float
  616. only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"}
  617. magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
  618. # remap necessary because an op name can have a bitwise and boolean implementation
  619. bitwise_ops = {"bitwise_and": "and", "bitwise_or": "or", "bitwise_xor": "xor"}
  620. always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}
  621. for name in math_op_names:
  622. sym_name = f"sym_{name}"
  623. always_float_magic_methods.add(sym_name)
  624. always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"}
  625. always_bool_magic_methods = {
  626. "eq",
  627. "ne",
  628. "gt",
  629. "lt",
  630. "le",
  631. "ge",
  632. "and",
  633. "or",
  634. "sym_not",
  635. "is_non_overlapping_and_dense",
  636. "is_integer",
  637. }
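# Sympy expression builders for binary magic methods. The methods that have both a `__foo__` and an `__rfoo__` form are collected in `reflectable_magic_methods` below.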
  639. def _sympy_float_truediv(a, b):
  640. from torch.utils._sympy.functions import FloatTrueDiv
  641. return FloatTrueDiv(a, b)
  642. def _sympy_int_truediv(a, b):
  643. from torch.utils._sympy.functions import IntTrueDiv
  644. return IntTrueDiv(a, b)
  645. def _sympy_floordiv(a, b):
  646. from torch.utils._sympy.functions import FloorDiv
  647. return FloorDiv(a, b)
  648. def _sympy_mod(a, b):
  649. from torch.utils._sympy.functions import Mod, PythonMod
  650. if a.is_nonnegative and b.is_nonnegative:
  651. return Mod(a, b)
  652. else:
  653. return PythonMod(a, b)
  654. def _sympy_pow_by_natural(a, b):
  655. from torch.utils._sympy.functions import PowByNatural
  656. return PowByNatural(a, b)
  657. def _sympy_float_pow(a, b):
  658. from torch.utils._sympy.functions import FloatPow
  659. return FloatPow(a, b)
  660. def _sympy_and(a, b):
  661. import sympy
  662. return sympy.And(a, b)
  663. def _sympy_or(a, b):
  664. import sympy
  665. return sympy.Or(a, b)
  666. def _sympy_lshift(a, b):
  667. from torch.utils._sympy.functions import LShift
  668. return LShift(a, b)
  669. def _sympy_rshift(a, b):
  670. from torch.utils._sympy.functions import RShift
  671. return RShift(a, b)
def _binary_search_insert_arg(ordered_args, new_arg):
    """Insert ``new_arg`` into ``ordered_args``, keeping sympy's arg order.

    ``ordered_args`` is a list assumed to already be sorted in sympy's
    canonical argument order. ``_args_sortkey`` and ``Basic.compare`` are
    presumably consistent with each other; confirm against the installed
    sympy.

    Returns:
        None if an element that compares equal to ``new_arg``
        (``Basic.compare(...) == 0``) is already present. Otherwise, the list
        with ``new_arg`` inserted at its sorted position.

    NOTE: the empty-list case and the two fast paths return a NEW list. The
    binary-search path instead inserts into ``ordered_args`` in place and
    returns it. Callers in this file pass a freshly built list.
    """
    if len(ordered_args) == 0:
        return [new_arg]

    from sympy.core.basic import _args_sortkey as sort_key, Basic

    # Fast path when new_arg > ordered_args[-1].
    if sort_key(ordered_args[-1]) < sort_key(new_arg):
        return ordered_args + [new_arg]

    # Fast path when new_arg < ordered_args[0].
    if sort_key(ordered_args[0]) > sort_key(new_arg):
        return [new_arg] + ordered_args

    # Standard binary search. When the loop exits, ``low`` is the insertion
    # index.
    low, high = 0, len(ordered_args) - 1

    while low <= high:
        mid = (low + high) // 2
        compare_result = Basic.compare(ordered_args[mid], new_arg)
        if compare_result == 0:
            # Already present: the caller treats None as "duplicate found".
            return None
        elif compare_result < 0:
            low = mid + 1
        else:
            high = mid - 1

    ordered_args.insert(low, new_arg)
    return ordered_args
def _optimized_add(
    lhs, rhs, lhs_is_optimized_summation=False, rhs_is_optimized_summation=False
):
    """
    Custom optimization for Add, used to speed up incremental binary summations with certain properties. The idea
    is that when we know the expression is a summation of unique symbols, all we need to know is the correct order of
    the symbols, and no other optimizations are needed. Instead of letting ``sympy.Add`` canonicalize the args, we
    build the Add directly from the already-ordered args (``sympy.Add._from_args``), which saves the following:
    1. Avoid running other optimizations when the Add is constructed.
    2. Manually figure out the order of the args for the new expression in log(n) comparisons instead of n*log(n)
    (comparing terms is expensive and shows in the profiles).
    The function returns a tuple of (1) a boolean that indicates whether the output is a summation of unique symbols,
    and (2) the resulting sympy expression.

    Args:
        lhs, rhs: the sympy expressions being added.
        lhs_is_optimized_summation, rhs_is_optimized_summation: caller-supplied
            flags saying an operand is already known to be such a summation.
            They are OR-ed with ``_is_symbols_binary_summation`` below.
    """
    import sympy
    from sympy.core.basic import _args_sortkey as sortkey

    # Build an Add directly from args that are already in canonical order.
    def make_optimized(ordered_args):
        if ordered_args is None:
            raise AssertionError("ordered_args is None")
        # Use _from_args directly to bypass _exec_constructor_postprocessors
        # which iterates over all args. This is safe because args are only
        # symbols or constants, which don't register postprocessors.
        # Pass is_commutative=True to avoid fuzzy_and check over all args.
        result = sympy.Add._from_args(ordered_args, is_commutative=True)
        return (True, result)

    from torch.utils._sympy.functions import _is_symbols_binary_summation

    lhs_is_optimized_summation |= _is_symbols_binary_summation(lhs)
    rhs_is_optimized_summation |= _is_symbols_binary_summation(rhs)

    if lhs_is_optimized_summation and rhs_is_optimized_summation:
        # (a0+a1..) + (a2+a3..) => (a0+a1+a2+a3)
        if sortkey(lhs._args[-1]) < sortkey(rhs._args[0]):
            return make_optimized(lhs._args + rhs._args)
        #  (a2+a3..) + (a0+a1..) => (a0+a1+a2+a3)
        if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]):
            return make_optimized(rhs._args + lhs._args)

        #  (a1+a3) + (a0+a2) => (a0+a1+a2+a3)
        # Interleaving is only attempted for small operands (at most 2 args each).
        if len(lhs._args) <= 2 and len(rhs._args) <= 2:
            new_args = list(lhs._args)
            for a in rhs._args:
                new_args = _binary_search_insert_arg(new_args, a)
                if new_args is None:
                    break
            # None means an element already exists.
            if new_args is not None:
                return make_optimized(new_args)

    # (a0+a2) + a1 => (a0+a1+a2)
    if lhs_is_optimized_summation and rhs.is_symbol:
        new_args = _binary_search_insert_arg(list(lhs._args), rhs)
        # None means an element already exists.
        if new_args is not None:
            return make_optimized(new_args)

    # a1 + (a0+a2)=> (a0+a1+a2)
    if rhs_is_optimized_summation and lhs.is_symbol:
        new_args = _binary_search_insert_arg(list(rhs._args), lhs)
        # None means an element already exists.
        if new_args is not None:
            return make_optimized(new_args)

    # General case (or a duplicate term was found): regular sympy Add, then
    # re-check whether the result still qualifies.
    result = sympy.Add(lhs, rhs)
    return (_is_symbols_binary_summation(result), result)
  756. def _bitwise_and(a, b):
  757. from torch.utils._sympy.functions import BitwiseFn_bitwise_and
  758. return BitwiseFn_bitwise_and(a, b)
  759. def _bitwise_or(a, b):
  760. from torch.utils._sympy.functions import BitwiseFn_bitwise_or
  761. return BitwiseFn_bitwise_or(a, b)
  762. def _bitwise_xor(a, b):
  763. from torch.utils._sympy.functions import BitwiseFn_bitwise_xor
  764. return BitwiseFn_bitwise_xor(a, b)
  765. reflectable_magic_methods = {
  766. "add": operator.add,
  767. "sub": operator.sub,
  768. "mul": operator.mul,
  769. "mod": _sympy_mod,
  770. "pow_by_natural": _sympy_pow_by_natural,
  771. "float_pow": _sympy_float_pow,
  772. "and": _sympy_and,
  773. "bitwise_and": _bitwise_and,
  774. "or": _sympy_or,
  775. "bitwise_or": _bitwise_or,
  776. "bitwise_xor": _bitwise_xor,
  777. "float_truediv": _sympy_float_truediv,
  778. "int_truediv": _sympy_int_truediv,
  779. "int_floordiv": _sympy_floordiv,
  780. "lshift": _sympy_lshift,
  781. "rshift": _sympy_rshift,
  782. }
  783. def _floor_ceil_helper(a, fn):
  784. import sympy
  785. if isinstance(a, sympy.Mul):
  786. aa = a.args
  787. if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
  788. coef = sympy.Integer(aa[0])
  789. if aa[0] == coef: # structural equality test
  790. return coef * aa[1]
  791. if (
  792. isinstance(a, sympy.Float)
  793. and a == sympy.Integer(a)
  794. or isinstance(a, sympy.Integer)
  795. ):
  796. return sympy.Integer(a)
  797. return fn(a)
  798. def _sympy_floor(a):
  799. from torch.utils._sympy.functions import FloorToInt
  800. return FloorToInt(a)
  801. # NB: this is Python trunc semantics which returns an int. Do NOT use this to
  802. # represent torch.trunc (which is float to float)
  803. def _sympy_trunc(a):
  804. from torch.utils._sympy.functions import TruncToInt
  805. return TruncToInt(a)
  806. def _sympy_ceil(a):
  807. from torch.utils._sympy.functions import CeilToInt
  808. return CeilToInt(a)
  809. def _sympy_eq(a, b):
  810. import sympy
  811. return sympy.Eq(a, b)
  812. def _sympy_ne(a, b):
  813. import sympy
  814. return sympy.Ne(a, b)
  815. def _sympy_gt(a, b):
  816. import sympy
  817. return sympy.Gt(a, b)
  818. def _sympy_lt(a, b):
  819. import sympy
  820. return sympy.Lt(a, b)
  821. def _sympy_le(a, b):
  822. import sympy
  823. return sympy.Le(a, b)
  824. def _sympy_ge(a, b):
  825. import sympy
  826. return sympy.Ge(a, b)
  827. def _sympy_min(a, b):
  828. from torch.utils._sympy.functions import Min
  829. return Min(a, b)
  830. def _sympy_max(a, b):
  831. from torch.utils._sympy.functions import Max
  832. return Max(a, b)
  833. def _sympy_ite(a, t, f):
  834. import sympy
  835. return sympy.Piecewise((t, a), (f, True))
current_module = sys.modules[__name__]


# Return a sympy builder that wraps its argument in the opaque
# ``torch.utils._sympy.functions.OpaqueUnaryFn_<name>`` function.
def _get_sym_math_fn(name):
    def fn(a):
        import torch.utils._sympy.functions

        return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a)

    return fn


# Install a module-level ``_sympy_<name>`` builder for every math op, with
# __name__/__qualname__ set to match the attribute name.
for name in math_op_names:
    priv_sympy_name = f"_sympy_{name}"
    fn = _get_sym_math_fn(name)
    fn.__qualname__ = fn.__name__ = priv_sympy_name
    setattr(current_module, priv_sympy_name, fn)

# Remove the temporary loop globals.
del fn, name, priv_sympy_name  # type: ignore[possibly-undefined]
  848. def _sympy_abs(a):
  849. import sympy
  850. return sympy.Abs(a)
  851. def _sympy_round(number, ndigits=None):
  852. from torch.utils._sympy.functions import RoundDecimal, RoundToInt
  853. if ndigits is None:
  854. return RoundToInt(number)
  855. else:
  856. return RoundDecimal(number, ndigits)
  857. def _sympy_sym_float(a):
  858. from torch.utils._sympy.functions import ToFloat
  859. # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly
  860. # reports that it is an integer
  861. return ToFloat(a)
  862. def _sympy_is_integer(a):
  863. import sympy
  864. from torch.utils._sympy.functions import ToFloat
  865. return sympy.Eq(ToFloat(sympy.floor(a)), a)
# Map from magic-method name to the function that builds its sympy
# expression. Each entry is turned into a SymNode method by the
# ``_make_node_magic`` registration loop further below.
magic_methods = {
    **reflectable_magic_methods,
    "sym_not": operator.invert,
    "pos": operator.pos,
    "eq": _sympy_eq,
    "ne": _sympy_ne,
    "gt": _sympy_gt,
    "lt": _sympy_lt,
    "le": _sympy_le,
    "ge": _sympy_ge,
    "floor": _sympy_floor,
    "trunc": _sympy_trunc,
    "sym_float": _sympy_sym_float,
    "ceil": _sympy_ceil,
    "neg": operator.neg,
    "sym_min": _sympy_min,
    "sym_max": _sympy_max,
    "sym_ite": _sympy_ite,
    "abs": _sympy_abs,
    "round": _sympy_round,
    "is_integer": _sympy_is_integer,
}


# Add the math ops (sym_sqrt, sym_cos, ...), whose builders were installed on
# this module as ``_sympy_<name>``.
for name in math_op_names:
    sym_name = f"sym_{name}"
    magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")

# Drop setup-only globals.
del name, sym_name, math_op_names, current_module  # type: ignore[possibly-undefined]
  892. def sympy_is_contiguous(sizes, strides):
  893. dim = len(sizes)
  894. return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1)))
  895. def sympy_is_contiguous_generic(sizes, strides, dim_order):
  896. import sympy
  897. dim = len(sizes)
  898. if len(dim_order) != dim:
  899. return sympy.false
  900. is_contiguous = sympy.true
  901. z = sympy.S.One
  902. # Contiguous if the strides make sense (or the dim is size 1)
  903. for d in dim_order:
  904. is_contiguous &= sympy.Eq(sizes[d], sympy.S.One) | sympy.Eq(strides[d], z)
  905. z *= sizes[d]
  906. # OR if any size is zero
  907. for d in range(dim):
  908. is_contiguous |= sympy.Eq(sizes[d], sympy.S.Zero)
  909. return is_contiguous
  910. # NB: There is a TODO in C++ to allow omitting the batch dim. If that
  911. # happens you will need to refactor this
  912. def sympy_is_channels_last_contiguous_2d(sizes, strides):
  913. return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0])
  914. def sympy_is_channels_last_contiguous_3d(sizes, strides):
  915. return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0])
  916. def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
  917. import sympy
  918. from torch.utils._sympy.functions import Max
  919. dim = len(sizes)
  920. if dim != len(dim_order):
  921. return sympy.false
  922. m = sympy.S.Zero
  923. r = sympy.true
  924. # special case for trivial C dimension. default to NCHW
  925. r &= sympy.Ne(strides[1], 0)
  926. for d in dim_order:
  927. r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m)
  928. # Fallback to NCHW as default layout for ambiguous cases
  929. # This is the flaw of implicit memory_format from strides.
  930. # N111 tensor with identical strides for size 1 dimension;
  931. # Two cases could lead us here:
  932. # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
  933. # b. N11W contiguous Tensor sliced on the W-dimension.
  934. # ([N,1,1,1]@[W,W,W,W])
  935. if d == 0:
  936. r &= sympy.Ne(m, strides[1])
  937. # This is necessary to:
  938. # 1. distinguish the memory_format of N1H1;
  939. # [H, 1, 1, 1] channels_last stride
  940. # [H, H, 1, 1] contiguous stride
  941. # 2. permutation of 1C1W:
  942. # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
  943. # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as
  944. # channels_last
  945. m = strides[d] * Max(sizes[d], 1)
  946. return r
  947. def sympy_is_channels_last_strides_2d(sizes, strides):
  948. return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0])
  949. def sympy_is_channels_last_strides_3d(sizes, strides):
  950. return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0])
  951. def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
  952. from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator
  953. return IsNonOverlappingAndDenseIndicator(*sizes, *strides)
  954. sizes_strides_methods = {
  955. # TODO: These could also be done with indicators, maybe it is better
  956. # for reasoning to do it that way
  957. "is_contiguous": sympy_is_contiguous,
  958. "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
  959. "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
  960. "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
  961. "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
  962. "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
  963. }
  964. def to_node(self, num):
  965. if isinstance(num, SymTypes):
  966. return num.node
  967. elif type(num) is bool:
  968. return self.wrap_bool(num)
  969. elif type(num) is int:
  970. return self.wrap_int(num)
  971. elif type(num) is float:
  972. return self.wrap_float(num)
  973. else:
  974. # NotImplemented is important so that Python tries the
  975. # other magic method
  976. return NotImplemented
  977. def wrap_node(x):
  978. # TODO: let C++ also take advantage of this
  979. if isinstance(x, SymNode) and x.constant is not None:
  980. return x.constant
  981. if x.is_int():
  982. return SymInt(x)
  983. elif x.is_float():
  984. return SymFloat(x)
  985. elif x.is_bool():
  986. return SymBool(x)
  987. else:
  988. raise AssertionError(f"unrecognized return type {x}")
  989. def method_to_operator(method):
  990. return METHOD_TO_OPERATOR[method]
def _make_node_magic(method, func):
    """Attach the SymNode-level implementation of magic method ``method``.

    ``func`` builds the sympy expression for the operation (the values of the
    ``magic_methods`` table). The nested implementation installed as
    ``SymNode._<method_attr>`` depends on ``method``:

    * names in ``unary_methods`` -> ``unary_magic_impl`` (one operand)
    * ``"sym_ite"``              -> ``sym_ite_impl`` (predicate, then, else)
    * ``"round"``                -> ``round_impl`` (optional ``ndigits``)
    * anything else              -> ``binary_magic_impl`` (two operands)
    """
    # Memoize the sympy builder (lru_cache requires hashable arguments).
    func = lru_cache(256)(func)

    # "and"/"or" are installed as SymNode._and_ / SymNode._or_.
    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"{method}_"
    else:
        method_attr = method

    # Collect the source files of a few torch/dynamo modules, plus dynamo's own
    # uninteresting files and "<string>".
    # NOTE(review): this helper is not referenced anywhere else in this
    # function; it may be dead code. Confirm before relying on it.
    def uninteresting_files() -> set[str]:
        import torch

        mods = [
            torch._dynamo.eval_frame,
            torch._dynamo.utils,
            torch.fx.experimental.sym_node,
            torch,
        ]
        import torch._dynamo.guards

        return (
            {inspect.getfile(m) for m in mods}
            | torch._dynamo.guards.uninteresting_files()
            | {"<string>"}
        )

    # Decorator for the impls below. It calls ``fn`` and, when structured
    # dtrace logging is enabled, emits an "expression_created" event that
    # records the method, the resulting node and its argument nodes.
    def capture_provenance(fn):
        @functools.wraps(fn)
        def wrapper(self, other=None):
            # Unary impls take only ``self``, so ``other`` stays None for them.
            if other is None:
                result = fn(self)
            else:
                result = fn(self, other)
            if torch._logging._internal.GET_DTRACE_STRUCTURED:
                if other is not None:
                    arguments = [self, other]
                else:
                    arguments = [self]

                # id() of an argument node, or None for constants, for the
                # result node itself, and for literal Integer/Float/true/false exprs.
                def get_id(sym_node) -> Optional[int]:
                    # We don't want to return an ID if the input is a constant
                    import sympy

                    if sym_node.constant is not None:
                        return None
                    elif id(sym_node) == id(result):
                        return None
                    elif isinstance(sym_node.expr, (sympy.Integer, sympy.Float)):
                        return None
                    elif sym_node.expr in (sympy.true, sympy.false):
                        return None
                    return id(sym_node)

                dtrace_structured(
                    "expression_created",
                    metadata_fn=lambda: {
                        "method": method,
                        "result": str(result),
                        "result_id": id(result),
                        "arguments": [str(a) for a in arguments],
                        # NOTE(review): get_id runs twice per argument here
                        # (once in the filter, once for the value).
                        "argument_ids": [
                            get_id(i) for i in arguments if get_id(i) is not None
                        ],
                        "user_stack": structured.get_user_stack(3),
                        "stack": structured.get_framework_stack(3),
                    },
                )

            return result

        return wrapper

    # Binary operation ``self <method> other`` on two SymNodes.
    @capture_provenance
    def binary_magic_impl(self, other):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        op = method_to_operator(method)

        # Concrete hint for the result, computed only when both operands have
        # hints. _NO_HINT is a sentinel (defined elsewhere), distinct from
        # None, meaning "no hint was computed here".
        out_hint: object = _NO_HINT
        if self.hint is not None and other.hint is not None:
            out_hint = op(self.hint, other.hint)

        # Under an active proxy mode, dispatch the op on the user-facing Sym*
        # wrappers and convert the result back to a node.
        if get_proxy_mode():
            return to_node(
                self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
            )
        # NOTE(review): other.hint is already read above, before this type check.
        if not isinstance(other, SymNode):
            raise AssertionError(f"Expected SymNode, got {type(other)}")

        optimized_summation = False
        try:
            if method == "mod":
                from torch.utils._sympy.functions import Mod, PythonMod

                # Special handling for mod that requires access to the value
                # ranges
                # Use Mod only when both operands are known non-negative, either
                # from sympy assumptions or from the shape env's bounds.
                # Otherwise use PythonMod.
                shape_env = self.shape_env
                if (
                    self.expr.is_nonnegative
                    or shape_env.bound_sympy(self.expr).lower >= 0
                ) and (
                    other.expr.is_nonnegative
                    or shape_env.bound_sympy(other.expr).lower >= 0
                ):
                    out = Mod(self.expr, other.expr)
                else:
                    out = PythonMod(self.expr, other.expr)
            elif method == "add":
                # see Note [optimized_summation]
                (optimized_summation, out) = _optimized_add(
                    self.expr,
                    other.expr,
                    self._optimized_summation,
                    other._optimized_summation,
                )
            elif method in ("eq", "ne", "ge", "gt", "le", "lt"):
                import sympy

                from torch.utils._sympy.symbol import symbol_is_type, SymT

                # Optimization: when one side is a single unbacked symbol
                # and other is constant, use evaluate=False to skip expensive
                # relational evaluation. We only do this for unbacked symbols
                # because they have no assumptions (like positive=True) that
                # sympy would use during evaluation.
                lhs_is_unbacked = self.expr.is_symbol and symbol_is_type(
                    self.expr, SymT.UNBACKED_INT
                )
                rhs_is_unbacked = other.expr.is_symbol and symbol_is_type(
                    other.expr, SymT.UNBACKED_INT
                )
                if (lhs_is_unbacked and other.expr.is_number) or (
                    rhs_is_unbacked and self.expr.is_number
                ):
                    rel_class = {
                        "eq": sympy.Eq,
                        "ne": sympy.Ne,
                        "ge": sympy.Ge,
                        "gt": sympy.Gt,
                        "le": sympy.Le,
                        "lt": sympy.Lt,
                    }[method]
                    out = rel_class(self.expr, other.expr, evaluate=False)
                else:
                    out = func(self.expr, other.expr)
            else:
                # TODO: consider constant prop here
                out = func(self.expr, other.expr)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
            raise
        sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out)
        pytype: type
        # This is not strictly correct. In Python, a**b may return complex when
        # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
        # returns a float while both arguments are ints: 2**(-1). Also, max and
        # min do not type promote. To avoid having data-dependent control flow
        # here, we just set the type to float if one of the args is a float. In
        # case of a type mismatch, we assume that it will be detected during
        # evaluation.
        if method in always_float_magic_methods:
            pytype = float
        elif method in always_bool_magic_methods:
            pytype = bool
        elif self.pytype is float or other.pytype is float:
            pytype = float
        else:
            pytype = self.pytype

        # Coerce a plain (non-Sym) hint to the result's pytype.
        if (
            pytype is not None
            and out_hint is not _NO_HINT
            and out_hint is not None
            and not isinstance(out_hint, SymTypes)
        ):
            out_hint = pytype(out_hint)  # type: ignore[arg-type]

        # Create a FX node that corresponds to the operation being applied to
        # this node.
        fx_node, _ = self.shape_env._create_fx_call_function(
            op, (self.fx_node, other.fx_node)
        )

        result = SymNode(
            out,
            self.shape_env,
            pytype,
            out_hint,  # type: ignore[arg-type]
            fx_node=fx_node,
            optimized_summation=optimized_summation,  # see Note [optimized_summation]
        )
        return result

    # Unary operation ``<method>(self)`` on a single SymNode.
    @capture_provenance
    def unary_magic_impl(self):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        op = method_to_operator(method)
        if get_proxy_mode():
            return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
        # TODO: consider constant prop here
        expr = self.expr
        # NOTE(review): no magic method in this file is named "ceiling" (the
        # ceil entry is "ceil"), so only "floor" reaches this call. Confirm
        # whether "ceil" was intended. _simplify_floor_div is defined on the
        # shape env, not here, so check its side effects before changing this.
        if method == "floor" or method == "ceiling":
            expr = self.shape_env._simplify_floor_div(expr)

        try:
            out = func(expr)
        except Exception:
            log.warning("failed to eval %s(%s)", method, expr)
            raise
        # NOTE(review): this logs ``func`` where the binary path logs ``method``.
        sym_node_log.debug("%s %s -> %s", func, expr, out)
        out_hint: object = _NO_HINT
        if self.hint is not None:
            out_hint = op(self.hint)
        # The result type is fixed for the always-int/bool/float methods;
        # otherwise it inherits self.pytype.
        pytype: type
        if method in always_int_magic_methods:
            pytype = int
        elif method in always_bool_magic_methods:
            pytype = bool
        elif method in always_float_magic_methods:
            pytype = float
        else:
            pytype = self.pytype

        fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,))
        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)  # type: ignore[arg-type]

    if method in unary_methods:
        setattr(SymNode, f"_{method_attr}", unary_magic_impl)
    elif method == "sym_ite":

        # Three-operand if-then-else: the result is then_node when pred_node
        # holds, else else_node.
        def sym_ite_impl(pred_node, then_node, else_node):
            from torch.fx.experimental.proxy_tensor import (
                get_proxy_mode,
                handle_sym_dispatch,
            )

            # NOTE(review): when pred_node.hint is None (unknown) this picks
            # else_node.hint instead of leaving the result hint unknown. Confirm
            # this is intended.
            out_hint = then_node.hint if pred_node.hint else else_node.hint
            if get_proxy_mode():
                return to_node(
                    pred_node,
                    handle_sym_dispatch(
                        sym_ite,
                        (
                            wrap_node(pred_node),
                            wrap_node(then_node),
                            wrap_node(else_node),
                        ),
                        {},
                    ),
                )

            try:
                out = func(pred_node.expr, then_node.expr, else_node.expr)
            except Exception:
                log.warning(
                    "failed to eval %s(%s, %s, %s)",
                    method,
                    pred_node.expr,
                    then_node.expr,
                    else_node.expr,
                )
                raise

            fx_node, _ = pred_node.shape_env._create_fx_call_function(
                sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node)
            )
            # The result's pytype is taken from then_node.
            return SymNode(
                out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node
            )

        setattr(SymNode, f"_{method_attr}", sym_ite_impl)
    elif method == "round":

        # round(self[, ndigits]) on a SymNode.
        def round_impl(self, ndigits=None):
            from torch.fx.experimental.proxy_tensor import (
                get_proxy_mode,
                handle_sym_dispatch,
            )

            op = builtins.round
            if get_proxy_mode():
                return to_node(
                    self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
                )

            expr = self.expr

            try:
                out = func(expr, ndigits)
            except Exception:
                log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
                raise

            # round(x) without ndigits yields an int; with ndigits the result
            # keeps self.pytype.
            if ndigits is None:
                pytype = int
            else:
                pytype = self.pytype

            # NOTE(review): the unknown-hint default here is None, unlike the
            # _NO_HINT sentinel used by the unary/binary paths.
            out_hint = None
            if self.hint is not None:
                out_hint = op(self.hint, ndigits)

            # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the
            # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here
            # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The
            # hack down below works, because all round function down the line all take ndigits=None as default in their
            # signature.
            # TODO: Remove the args construction below if a different sentinel is used by FX.
            # ezyang(May 2024): LOL
            args = [self.fx_node]
            if ndigits is not None:
                args.append(ndigits)
            fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args))
            return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

        setattr(SymNode, f"_{method_attr}", round_impl)
    else:
        setattr(SymNode, f"_{method_attr}", binary_magic_impl)
def _make_node_sizes_strides(method, func):
    """Install the sizes/strides query ``method`` (e.g. ``is_contiguous``).

    Two callables are created:

    * ``SymNode._<method>(sizes, strides)``: takes SymNodes, builds the sympy
      expression with ``func`` and returns a new SymNode. Its pytype is int
      for ``*_indicator`` methods and bool otherwise.
    * a module-level ``<method>(sizes, strides)`` for user values. It is
      installed only if this module does not already define ``<method>``.
    """
    # NB: don't LRU cache, lots of arguments

    def sizes_strides_impl(self, sizes, strides):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        # The module-level function of the same name serves as the
        # Python-level op, both for proxy dispatch and for computing the hint.
        op = getattr(sys.modules[__name__], method)
        if get_proxy_mode():
            return to_node(
                self,
                handle_sym_dispatch(
                    op,
                    ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]),
                    {},
                ),
            )
        size_exprs = [s.expr for s in sizes]
        stride_exprs = [s.expr for s in strides]
        try:
            out = func(size_exprs, stride_exprs)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs)
            raise
        # bool is never expandable
        # NOTE(review): the comment above does not obviously relate to the hint
        # computation below.

        # Compute a hint only if every size and every stride has one. The
        # for/else clauses run only when their loop finished without ``break``.
        size_hints = []
        out_hint = None
        for s in sizes:
            if s.hint is None:
                break
            size_hints.append(s.hint)
        else:
            stride_hints = []
            for s in strides:
                if s.hint is None:
                    break
                stride_hints.append(s.hint)
            else:
                out_hint = op(size_hints, stride_hints)

        # NB: This is the indicator function, not the actual bool!
        pytype: type
        if method.endswith("_indicator"):
            pytype = int
        else:
            pytype = bool
        return SymNode(out, self.shape_env, pytype, out_hint)

    setattr(SymNode, f"_{method}", sizes_strides_impl)

    # TODO: This is technically hotpath, but in the ideal end state
    # guards on this will resolve at a higher level so you never
    # spend time in this code
    # Module-level entry point taking user values (plain ints and/or SymInts).
    def sizes_strides_user(sizes, strides):
        import sympy

        from torch.fx.experimental.symbolic_shapes import (
            eval_is_non_overlapping_and_dense,
        )

        # If any argument is a SymInt, delegate to ``<method>`` on that
        # SymInt's node, converting every argument to a node first.
        # NOTE(review): this calls ``a.node.<method>`` (no leading underscore),
        # which is assumed to be defined on the node class elsewhere.
        for a in itertools.chain(sizes, strides):
            if isinstance(a, SymInt):
                return wrap_node(
                    getattr(a.node, method)(
                        [to_node(a.node, b) for b in sizes],
                        [to_node(a.node, b) for b in strides],
                    )
                )
        # From here on every argument is concrete.
        if method == "is_non_overlapping_and_dense_indicator":
            return eval_is_non_overlapping_and_dense(sizes, strides)
        else:
            # TODO: this is an awful implementation
            return bool(
                func(
                    [sympy.sympify(a) for a in sizes],
                    [sympy.sympify(a) for a in strides],
                )
            )

    # Skip for is_non_overlapping_and_dense_indicator
    # (presumably because a module-level function with that name already
    # exists; confirm). Existing module attributes are never overwritten.
    if not hasattr(sys.modules[__name__], method):
        setattr(sys.modules[__name__], method, sizes_strides_user)
# Attach the SymNode-level implementation (``SymNode._<attr>``) of every magic
# method and every sizes/strides query. ``method`` and ``func`` remain bound
# at module scope after these loops.
for method, func in magic_methods.items():
    _make_node_magic(method, func)

for method, func in sizes_strides_methods.items():
    _make_node_sizes_strides(method, func)
  1356. def _make_user_magic(method, user_type):
  1357. # User magic takes care of wrapping the other operand into a node,
  1358. # so that our internal logic can assume everything is nodes
  1359. if method in magic_methods_on_operator_with_trailing_underscore:
  1360. method_attr = f"sym_{method}"
  1361. else:
  1362. method_attr = method
  1363. def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
  1364. if isinstance(x, (int, float, bool)):
  1365. return x
  1366. if isinstance(x, SymInt):
  1367. return x.node.guard_int("", 0)
  1368. if isinstance(x, SymBool):
  1369. return x.node.guard_bool("", 0)
  1370. raise AssertionError("expect to be called with constant SymBools")
  1371. def is_constant(x):
  1372. if isinstance(x, (int, float, bool)):
  1373. return True
  1374. if isinstance(x, (SymInt, SymFloat, SymBool)):
  1375. return x.node.is_constant()
  1376. return False
  1377. # Promotion rules for binary operations. NB: we preserve PYTHON semantics
  1378. # - if args are same type, do nothing
  1379. # - if one arg is float, promote other arg to float
  1380. # - nb: this applies to floordiv, even though output is integral
  1381. # (it's still float)
  1382. # - pow is funny business
  1383. # - if both ints
  1384. # - trigger a guard on exponent >= 0
  1385. # - if non-negative, output is int
  1386. # - otherwise, output is float
  1387. # - otherwise, promote other arg to float
  1388. # - nb: complex is impossible to handle correctly lol, with
  1389. # negative base and integral float need to diverge semantics and
  1390. # just always return complex. Neener neener pretend this problem
  1391. # doesn't exist
  1392. # - equality is pain: Python does the fancy thing where it unpacks the
  1393. # mantissa from the float and then compares that against the int.
  1394. # Which means it is able to tell that
  1395. # 9007199254740993 != 9007199254740992. (rather than if the LHS was
  1396. # promoted to float, in which case it would have truncated to the RHS
  1397. # and subsequently been equal). We'll model this exactly by having
  1398. # special mixed type equality operations. Unfortunately, we need to
  1399. # do this for all comparison operations (maybe I'll only implement
  1400. # compare)
  1401. # - sym_ite mumble mumble really shouldn't allow mixed but whatever
  1402. if method in bool_becomes_int_magic_methods:
  1403. def promote(x):
  1404. """Implements True+True=2, which works in python but not sympy"""
  1405. if isinstance(x, SymBool):
  1406. return SymInt(x.node.wrap_int(int(x)))
  1407. return x
  1408. else:
  1409. def promote(x):
  1410. return x
  def promote2(self, other):
      """Apply Python's int -> float promotion to the operands of a binary op.

      Only the methods listed in ``promoting_methods`` are affected. If either
      operand is a float or SymFloat, the other one is converted with
      ``torch.sym_float``. In every other case the operands come back untouched.
      """
      # TODO: Remove eq and the other relations from this list. CPython compares
      # mixed int/float values with more precision than a plain float64
      # promotion gives, so these need dedicated handling as well.
      # int_truediv is deliberately absent: both of its arguments are ints, so
      # there is never anything to promote.
      promoting_methods = (
          "add",
          "sub",
          "mul",
          "mod",
          "float_pow",
          "float_truediv",
          "int_floordiv",
          "sym_min",
          "sym_max",
          # TODO: remove these
          "eq",
          "ne",
          "gt",
          "lt",
          "le",
          "ge",
      )
      if method not in promoting_methods:
          return self, other

      float_like = (float, torch.SymFloat)
      self_is_float = isinstance(self, float_like)
      other_is_float = isinstance(other, float_like)
      if not (self_is_float or other_is_float):
          # Neither side is a float: nothing to promote.
          return self, other

      if not self_is_float:
          self = torch.sym_float(self)
      if not other_is_float:
          other = torch.sym_float(other)
      return self, other
  # Before and after performing the operation, check whether any operands are
  # constant. If so, extract the constant values first. If `self` itself is a
  # constant, "redispatch" by calling back into the operator. This means that
  # operations involving SymBool sometimes return plain bools.
  # Alternatively, we could rewrap the result into a constant SymBool (e.g. by
  # implementing wrap_bool in ConstantSymNodeImpl), but we don't do that today
  # for no particular reason.
  def unary_magic_impl(self):
      """Apply the unary operation ``method`` to ``self``.

      A constant operand is unwrapped and handed to the matching Python
      operator. A symbolic operand is dispatched to its SymNode, and the result
      is rewrapped with wrap_node.
      """
      operand = promote(self)
      if not is_constant(operand):
          return wrap_node(getattr(operand.node, method_attr)())
      python_op = method_to_operator(method)
      return python_op(get_constant(operand))
  def binary_magic_impl(self, other):
      """Compute ``self <op> other`` for the binary operation ``method``.

      Returns NotImplemented for operand types this hook does not handle, or
      when ``other`` cannot be turned into a node. A constant ``self`` is
      redispatched to the Python operator. Otherwise the work is done at the
      SymNode level, and a constant result is unwrapped with get_constant.
      """
      accepted = (int, float, bool, SymInt, SymFloat, SymBool)
      if not isinstance(other, accepted):
          return NotImplemented
      sym_node_log.debug("MAGIC %s %s %s", method, self, other)

      # Arguments are evaluated left to right: promote(self), then promote(other).
      lhs, rhs = promote2(promote(self), promote(other))
      if is_constant(lhs):
          return (method_to_operator(method))(get_constant(lhs), rhs)
      if is_constant(rhs):
          rhs = get_constant(rhs)

      rhs_node = to_node(lhs.node, rhs)
      if rhs_node is NotImplemented:
          return NotImplemented
      result = wrap_node(getattr(lhs.node, method_attr)(rhs_node))
      return get_constant(result) if is_constant(result) else result
  def rbinary_magic_impl(self, other):
      """Reflected form of the binary operation: computes ``other <op> self``.

      This mirrors binary_magic_impl, with the operands swapped when the
      operation is finally dispatched. Returns NotImplemented for unsupported
      operand types, or when ``other`` cannot be converted to a node.
      """
      if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
          return NotImplemented
      # Log reflected calls as binary_magic_impl logs forward ones, so both
      # directions of a mixed operation are visible when debugging. The
      # operands are logged in the order the operation actually computes them.
      sym_node_log.debug("MAGIC r%s %s %s", method, other, self)
      self = promote(self)
      other = promote(other)
      self, other = promote2(self, other)
      if is_constant(self):
          return (method_to_operator(method))(other, get_constant(self))
      if is_constant(other):
          other = get_constant(other)
      other_node = to_node(self.node, other)
      if other_node is NotImplemented:
          return NotImplemented
      # Reflected: `other` is the left operand of the node-level operation.
      ret = wrap_node(getattr(other_node, method_attr)(self.node))
      return get_constant(ret) if is_constant(ret) else ret
  def setattrs(user_type, attr, symnode_impl):
      """Install ``symnode_impl`` as ``user_type.<attr>``.

      When ``user_type`` is SymInt, also install a DynamicInt counterpart. It
      converts DynamicInt arguments via ``.real`` and calls the plain ``int``
      method. A non-bool int result is rewrapped as a DynamicInt; any other
      result is returned unchanged.
      """
      setattr(user_type, attr, symnode_impl)
      if user_type is not SymInt:
          return

      def dynamic_int_impl(*args):
          plain_args = [arg.real if isinstance(arg, DynamicInt) else arg for arg in args]
          result = getattr(int, attr)(*plain_args)
          if isinstance(result, bool) or not isinstance(result, int):
              return result
          return DynamicInt(result)

      setattr(DynamicInt, attr, dynamic_int_impl)
  # Install the implementation(s) for `method` on user_type. The attribute
  # name depends on the kind of method being registered.
  if method in unary_magic_methods:
      setattrs(user_type, f"__{method}__", unary_magic_impl)
  elif method in unary_nonmagic_methods:
      # Non-dunder methods keep the metadata (name, docstring, ...) of the
      # method already defined on user_type.
      orig = getattr(user_type, method)
      setattrs(user_type, method, update_wrapper(unary_magic_impl, orig))
  elif method == "sym_ite":

      def sym_ite_magic_impl(pred, then_val, else_val):
          """Symbolic if-then-else selecting then_val or else_val by ``pred``.

          Both branches are converted to nodes. They must end up as SymNodes
          of the same pytype. Returns NotImplemented if either branch cannot
          be converted.
          """
          pred_node = pred.node
          then_node = to_node(pred_node, then_val)
          else_node = to_node(pred_node, else_val)
          if then_node is NotImplemented or else_node is NotImplemented:
              return NotImplemented
          if not (
              isinstance(then_node, SymNode)
              and isinstance(else_node, SymNode)
              and then_node.pytype == else_node.pytype
          ):
              raise AssertionError(
                  "then_node and else_node must be SymNodes with same pytype"
              )
          ret = wrap_node(getattr(pred_node, method_attr)(then_node, else_node))
          # Use the shared is_constant helper, as binary_magic_impl and
          # rbinary_magic_impl do for wrap_node's result, instead of reaching
          # into ret.node directly. is_constant also accepts plain Python values.
          return get_constant(ret) if is_constant(ret) else ret

      setattrs(user_type, f"__{method}__", sym_ite_magic_impl)
  elif method == "round":

      def round_magic_impl(self, ndigits=None):
          """round(self, ndigits): constants use builtins.round, else the SymNode."""
          if is_constant(self):
              return builtins.round(get_constant(self), ndigits)
          return wrap_node(getattr(self.node, method)(ndigits))

      setattrs(user_type, f"__{method}__", round_magic_impl)
  else:
      # Binary operations. Bitwise ops are installed under the dunder name
      # given by the bitwise_ops mapping.
      method_name = method
      if method in bitwise_ops:
          method_name = bitwise_ops[method]
      setattrs(user_type, f"__{method_name}__", binary_magic_impl)
      if method in reflectable_magic_methods:
          setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl)
  # Register the user-facing magic methods.
  # - Methods in only_bool_magic_methods go on SymBool only.
  # - Methods in only_float_magic_methods go on SymFloat only.
  # - Every other method goes on SymInt, and also on SymFloat unless it is a
  #   bitwise op. Methods in also_bool_magic_methods or
  #   bool_becomes_int_magic_methods are additionally installed on SymBool.
  for method in magic_methods:  # type: ignore[assignment]
      if method in only_bool_magic_methods:
          _make_user_magic(method, SymBool)
      elif method in only_float_magic_methods:
          _make_user_magic(method, SymFloat)
      else:
          if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
              _make_user_magic(method, SymBool)
          _make_user_magic(method, SymInt)
          if method not in bitwise_ops:
              _make_user_magic(method, SymFloat)
  # Remove the loop variables from the module namespace.
  # NOTE(review): `func` is not bound by the loop above. Presumably it leaks
  # from an earlier module-level loop; confirm, otherwise this raises NameError.
  del method
  del func