runtime_assert.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import logging
  4. import operator
  5. import sys
  6. from typing import Any, Optional, TYPE_CHECKING
  7. # Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
  8. if TYPE_CHECKING:
  9. import sympy
  10. from torch.fx.experimental.symbolic_shapes import ShapeEnv
  11. else:
  12. ShapeEnv = Any
  13. import torch
  14. import torch.utils._pytree as pytree
  15. from torch import fx
  16. from torch._subclasses.meta_utils import is_sparse_any
  17. from torch.fx._compatibility import compatibility
  18. from torch.fx._utils import lazy_format_graph_code
  19. from torch.fx.experimental.proxy_tensor import py_sym_types
  20. from torch.fx.experimental.sym_node import SymNode
  21. from torch.fx.graph_module import GraphModule
  22. __all__ = ["insert_deferred_runtime_asserts"]
  23. log = logging.getLogger(__name__)
  24. graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code_verbose")
  25. def _get_example_value(node: fx.Node) -> Optional[str]:
  26. """
  27. Get the example value key for a node, since dynamo uses "example_value"
  28. while non-strict export uses "val.
  29. """
  30. if "example_value" in node.meta:
  31. return node.meta["example_value"]
  32. elif "val" in node.meta:
  33. return node.meta["val"]
  34. else:
  35. return None
  36. def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]:
  37. val = _get_example_value(node)
  38. if isinstance(val, py_sym_types):
  39. return val.node.expr
  40. return None
  41. @compatibility(is_backward_compatible=True)
  42. def insert_deferred_runtime_asserts(
  43. gm: GraphModule,
  44. shape_env: ShapeEnv,
  45. name: str,
  46. export: bool = False,
  47. ) -> None:
  48. """
  49. During tracing, we may have discovered that some data-dependent values
  50. had runtime assert on them; e.g., torch.empty(x.item()) induces a runtime
  51. that x.item() >= 0. These asserts can happen unpredictably during fake
  52. tensor propagation, so we cannot conveniently insert them into the FX graph
  53. when they occur. Instead, we accumulate them in the ShapeEnv, and in this
  54. pass insert them into the graph as proper tests.
  55. This pass also deduplicates size-related computation, CSE-ing ops that produce
  56. symbolic values and/or are involved in runtime asserts. Additionally, shape calls
  57. (size/stride/storage_offset) are turned into compute on input sizes if possible,
  58. allowing intermediate tensors to be freed earlier. For example, here dynamo will
  59. DCE the cat and repeat calls:
  60. z = torch.cat([x, x], dim=0) # 2*s0
  61. w = z.repeat(y.shape[0]) # 2*s0*s1
  62. _w = w.shape[0]
  63. # something with _w, but not w ...
  64. # turns into ->
  65. _w0 = 2 * s0
  66. _w = _w0 * s1
  67. # where s0, s1 are either SymInt graph inputs, or the result of added size calls
  68. Redundant torch._check or torch.ops.aten._assert_scalar.default calls that assert
  69. the same expression, and redundant constrain_range calls are also deduplicated.
  70. Additionally, because single-symbol bound checks (e.g. u0 >= 0, u0 <= 5) accumulate
  71. information in the ShapeEnv, the ShapeEnv contains min/max bounds for each symbol,
  72. and we delete all previous calls, adding bound checks at the end of this pass.
  73. """
  74. # Import sympy locally
  75. import sympy
  76. from torch._export.passes._node_metadata_hook import _set_node_metadata_hook
  77. from torch.fx.experimental.symbolic_shapes import (
  78. _get_placeholder_expr,
  79. _has_uninterpretable_sympy_function,
  80. CallMethodKey,
  81. cast_symbool_to_symint_guardless,
  82. ConvertIntKey,
  83. DivideByKey,
  84. free_symbols,
  85. InnerTensorKey,
  86. resolve_unbacked_bindings,
  87. )
  88. from torch.utils._sympy.numbers import int_oo
  89. from torch.utils._sympy.reference import (
  90. OptimizedPythonReferenceAnalysis,
  91. PythonReferenceAnalysis,
  92. )
  93. from torch.utils._sympy.value_ranges import ValueRanges
  94. # TODO: Request simplification on runtime asserts before emitting them
  95. ras_by_symbol = shape_env.deferred_runtime_asserts.copy()
  96. graph = gm.graph
  97. tracer = fx.proxy.GraphAppendingTracer(graph)
  98. graph_code_log.debug(
  99. "%s",
  100. lazy_format_graph_code(
  101. f"pre insert_deferred_runtime_asserts {name}", gm, colored=True
  102. ),
  103. )
  104. # We are going to mutate the dict
  105. expr_to_proxy: dict[sympy.Expr, fx.Proxy] = {}
  106. placeholders = set()
  107. first_non_placeholder = None
  108. for node in graph.nodes:
  109. if node.op != "placeholder":
  110. first_non_placeholder = node
  111. break
  112. else:
  113. placeholders.add(node)
  114. def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool:
  115. """
  116. If a size/stride/storage offset call on an intermediate tensor,
  117. we can try to compute the value from input shapes instead.
  118. """
  119. return (
  120. (val := _get_sym_val(node)) is not None
  121. and not isinstance(val, sympy.Number)
  122. # this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported
  123. and not _has_uninterpretable_sympy_function(val)
  124. and any(
  125. isinstance(arg, fx.Node)
  126. and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size))
  127. and arg.op != "placeholder"
  128. for arg in node.args
  129. )
  130. )
  131. # Figure out what key to use, val or example_value
  132. val_key = "val"
  133. for node in graph.nodes:
  134. if "example_value" in node.meta:
  135. val_key = "example_value"
  136. break
  137. elif "val" in node.meta:
  138. break
  139. def _node_metadata_hook(
  140. node: torch.fx.Node,
  141. stack_trace: Optional[str] = None,
  142. nn_module_stack: Optional[dict[str, Any]] = None,
  143. custom: Optional[dict[str, Any]] = None,
  144. ) -> None:
  145. fake_args = pytree.tree_map(
  146. lambda arg: (
  147. _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg
  148. ),
  149. node.args,
  150. )
  151. try:
  152. target = node.target
  153. if node.op == "call_method":
  154. if not isinstance(node.target, str):
  155. raise AssertionError(
  156. f"Expected str target, got {type(node.target)}"
  157. )
  158. target = getattr(fake_args[0], node.target)
  159. fake_args = fake_args[1:]
  160. node.meta[val_key] = target(*fake_args) # type: ignore[operator]
  161. except NotImplementedError:
  162. # This can happen when attempting to reify a symbol with an unsupported call_function node,
  163. # e.g. with NestedTensors + sym_size.int via match_symbol().
  164. # This seems to be fine, as the node gets CSE'd and deleted later in favor of a SymInt graph input.
  165. pass
  166. except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
  167. # This can happen when node args are symints
  168. # e.g. test/dynamo/test_export.py -k test_export_preserve_constraints_as_metadata_tensor
  169. # aten.sym_constrain_range_for_size(u0)
  170. pass
  171. if stack_trace is not None:
  172. node.meta["stack_trace"] = stack_trace
  173. if nn_module_stack is not None:
  174. node.meta["nn_module_stack"] = nn_module_stack
  175. if custom is not None:
  176. node.meta["custom"] = custom
  177. # Track asserts/checks we've added
  178. added_asserts: set[sympy.Expr] = set()
  179. constrained_unbacked_symbols: set[sympy.Symbol] = set()
  180. Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis
  181. def _sympy_interp(expr_to_proxy, expr):
  182. # sympy_interp() with hash consing
  183. from sympy import Integer, Number, Symbol
  184. from sympy.logic.boolalg import BooleanAtom
  185. from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
  186. # hash cons
  187. if expr in expr_to_proxy:
  188. return expr_to_proxy[expr]
  189. # base cases, don't cache
  190. if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)):
  191. return sympy_interp(Analysis, expr_to_proxy, expr)
  192. # hash cons on arguments, run expr handler
  193. expr_to_proxy[expr] = _run_sympy_handler(
  194. Analysis,
  195. [_sympy_interp(expr_to_proxy, arg) for arg in expr.args],
  196. expr,
  197. )
  198. return expr_to_proxy[expr]
  199. def _is_bound_expr_for_symbol(expr: "sympy.Expr") -> bool:
  200. # This is probably unnecessary, but since torch._check() calls for single-symbol bounds
  201. # like u0 >= 0, 10 >= u0 accumulate range info in the ShapeEnv, we designate these calls as redundant
  202. # and instead add 2 runtime asserts at the end of this pass, if the min/max bounds are non-trivial.
  203. if len(expr.args) != 2 or expr.func not in (sympy.LessThan, sympy.GreaterThan):
  204. return False
  205. lhs, rhs = expr.args
  206. return (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Number)) or (
  207. isinstance(rhs, sympy.Symbol) and isinstance(lhs, sympy.Number)
  208. )
  209. def add_runtime_asserts(ras):
  210. for ra in ras:
  211. if (
  212. # redundant
  213. ra.expr in added_asserts
  214. # if we've already added a constrain_range call for this symbol,
  215. # then single-symbol bound asserts like u0 >= 0, u0 <= 5 are redundant.
  216. or (
  217. len(ra.expr.free_symbols) == 1
  218. and next(iter(ra.expr.free_symbols)) in constrained_unbacked_symbols
  219. and _is_bound_expr_for_symbol(ra.expr)
  220. )
  221. # don't try to reify sympy functions we can't turn into FX nodes
  222. or _has_uninterpretable_sympy_function(ra.expr)
  223. ):
  224. continue
  225. log.debug("inserting runtime assert %s", ra.expr)
  226. # Need to process ALL free symbols, not just unbacked ones
  227. fvs = free_symbols(ra.expr)
  228. missing = fvs - expr_to_proxy.keys()
  229. if missing:
  230. i1 = min(missing, key=str)
  231. # TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689
  232. # assert shape_env.is_unbacked_symint(i1), i1
  233. ras_by_symbol.setdefault(i1, []).append(ra)
  234. else:
  235. # Convert the sympy expression into a sequence of FX
  236. # nodes
  237. res = _sympy_interp(expr_to_proxy, ra.expr).node
  238. graph.call_function(
  239. torch.ops.aten._assert_scalar.default,
  240. # TODO: use ra.msg here, but it's pretty
  241. # useless right now
  242. (
  243. res,
  244. f"Runtime assertion failed for expression {ra.expr} on node '{res}'",
  245. ),
  246. )
  247. added_asserts.add(ra.expr)
  248. nodes = list(graph.nodes)
  249. for i, node in enumerate(nodes[:-1]):
  250. # Placeholders can match symbols, but when we destructure them
  251. # with size we have to make sure we insert the nodes after all
  252. # the placeholders
  253. with (
  254. graph.inserting_before(
  255. nodes[i + 1] if node not in placeholders else first_non_placeholder
  256. ),
  257. _set_node_metadata_hook(
  258. gm,
  259. functools.partial(
  260. _node_metadata_hook,
  261. stack_trace=node.meta.get("stack_trace"),
  262. nn_module_stack=node.meta.get("nn_module_stack"),
  263. custom=node.meta.get("custom"),
  264. ),
  265. ),
  266. ):
  267. # Unfortunately, this logic still must remain because manual
  268. # make_fx calls may not explicitly bind all symbolic ints as
  269. # arguments to the function, so we must infer it from the other
  270. # arguments
  271. if (
  272. node in placeholders
  273. and (example_value := _get_example_value(node)) is not None
  274. ):
  275. def match_symbol(symint, cb):
  276. if (
  277. isinstance(symint, torch.SymInt)
  278. and isinstance(symint.node, SymNode)
  279. and isinstance(
  280. s := _get_placeholder_expr(symint.node), sympy.Symbol
  281. )
  282. and s not in expr_to_proxy
  283. ):
  284. expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer)
  285. log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
  286. match_symbol(example_value, lambda: node)
  287. if isinstance(t := example_value, torch.Tensor):
  288. for i, s in enumerate(t.size()):
  289. match_symbol(
  290. s,
  291. lambda: graph.call_function(
  292. torch.ops.aten.sym_size.int, (node, i)
  293. ),
  294. )
  295. if not is_sparse_any(t):
  296. for i, s in enumerate(t.stride()):
  297. match_symbol(
  298. s,
  299. lambda: graph.call_function(
  300. torch.ops.aten.sym_stride.int, (node, i)
  301. ),
  302. )
  303. match_symbol(
  304. t.storage_offset(),
  305. lambda: graph.call_function(
  306. torch.ops.aten.sym_storage_offset.default, (node,)
  307. ),
  308. )
  309. # Handle asserts that aren't associated with any symbol. This
  310. # doesn't really have to be in the loop as it will only run once,
  311. # it just needs to happen right after the placeholders.
  312. # insert this after placeholders & added sym nodes, and before non-placeholders.
  313. if node == first_non_placeholder:
  314. add_runtime_asserts(ras_by_symbol.pop(None, [])) # type: ignore[call-overload]
  315. # deduplicate asserts already present in graph, and remove trivial asserts
  316. if node.target in (
  317. torch._check,
  318. torch.ops.aten._assert_scalar.default,
  319. ):
  320. cond = node.args[0] if node.args else node.kwargs.get("cond")
  321. if (
  322. cond == True # noqa: E712
  323. or (assert_expr := _get_sym_val(cond)) in expr_to_proxy
  324. and assert_expr in added_asserts
  325. ):
  326. arg = cond
  327. gm.graph.erase_node(node)
  328. if isinstance(arg, fx.Node) and not arg.users:
  329. gm.graph.erase_node(arg)
  330. else:
  331. added_asserts.add(assert_expr) # type: ignore[arg-type]
  332. # hash cons, replace function calls that return torch.SymInts with direct references to
  333. # FX nodes built up to reify the sympy expression.
  334. if (
  335. node.op != "placeholder"
  336. and (sym_expr := _get_sym_val(node)) is not None
  337. ):
  338. # this guards against deleting calls like item() that produce new untracked symbols
  339. def has_new_untracked_symbols():
  340. # pyrefly: ignore [missing-attribute]
  341. for symbol in sym_expr.free_symbols:
  342. if symbol not in expr_to_proxy:
  343. return True
  344. return False
  345. # this guards against deleting calls that produce unbacked bindings we haven't yet seen.
  346. # in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint
  347. # (is backed), but produces an unbacked symbol. In this case keep the node alive.
  348. resolved_unbacked_bindings = resolve_unbacked_bindings(
  349. shape_env, node.meta.get("unbacked_bindings", {})
  350. )
  351. def has_new_unbacked_bindings():
  352. if resolved_unbacked_bindings is None:
  353. raise AssertionError("resolved_unbacked_bindings is None")
  354. for key in resolved_unbacked_bindings:
  355. if key not in expr_to_proxy:
  356. return True
  357. return False
  358. # maybe re-reify expression, replace current node
  359. if (
  360. sym_expr in expr_to_proxy
  361. or ( # example value is redundant
  362. _is_intermediate_tensor_sym_call(node)
  363. # shape call on intermediate tensor, turn into computation on input shapes
  364. and not has_new_untracked_symbols()
  365. )
  366. ) and not has_new_unbacked_bindings():
  367. if _is_intermediate_tensor_sym_call(
  368. node
  369. ): # reify from input shapes
  370. expr_to_proxy[sym_expr] = _sympy_interp(
  371. expr_to_proxy,
  372. sym_expr,
  373. ) # type: ignore[arg-type]
  374. # won't try DCE-ing tensor compute here
  375. hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type]
  376. node.replace_all_uses_with(hash_node)
  377. gm.graph.erase_node(node)
  378. log.debug(
  379. "CSE node %s -> %s for expr %s",
  380. node,
  381. hash_node,
  382. sym_expr,
  383. )
  384. # store node in hash cons, don't delete/replace
  385. elif sym_expr not in expr_to_proxy and not isinstance(
  386. sym_expr,
  387. (sympy.Number, sympy.logic.boolalg.BooleanAtom),
  388. ): # don't hash cons primitives
  389. expr_to_proxy[sym_expr] = fx.Proxy(node, tracer=tracer) # type: ignore[arg-type]
  390. # We add sym_constrain_range calls for symbols later in any case if they're size-like or range-constrained,
  391. # so calls before that are redundant.
  392. if node.target in (
  393. torch.ops.aten.sym_constrain_range.default,
  394. torch.ops.aten.sym_constrain_range_for_size.default,
  395. ):
  396. gm.graph.erase_node(node)
  397. defs = []
  398. # AOTAutograd will create new symbols as the unbacked_bindings keys, which PropagateSymInts will set as
  399. # equivalent, but the refinement calls we perform in this pass may struggle with associating the two.
  400. # More concretely, when re-exporting/tracing, constraining only the new symbol may not communicate enough
  401. # information about the old symbol when we re-export, raising errors on data-dependent guards.
  402. # Call resolve_unbacked_bindings() to get the original symbol if present, otherwise we take it as is.
  403. if unbacked_bindings := resolve_unbacked_bindings(
  404. shape_env, node.meta.get("unbacked_bindings")
  405. ):
  406. for s, keypath in unbacked_bindings.items():
  407. defs.append(s)
  408. # TODO: some CSE when generating these nodes can probably
  409. # help reduce graph size and improve compile time
  410. def go(node, keypath):
  411. if keypath == ():
  412. return node
  413. if (
  414. len(keypath) >= 2
  415. and isinstance(keypath[0], CallMethodKey)
  416. and isinstance(keypath[1], pytree.SequenceKey)
  417. ):
  418. if keypath[0].name == "size":
  419. return go(
  420. graph.call_function(
  421. torch.ops.aten.sym_size.int,
  422. (node, keypath[1].idx),
  423. ),
  424. keypath[2:],
  425. )
  426. if keypath[0].name == "stride":
  427. return go(
  428. graph.call_function(
  429. torch.ops.aten.sym_stride.int,
  430. (node, keypath[1].idx),
  431. ),
  432. keypath[2:],
  433. )
  434. return go(
  435. graph.call_method(
  436. keypath[0].name, (node, keypath[1].idx)
  437. ),
  438. keypath[2:],
  439. )
  440. elif isinstance(keypath[0], CallMethodKey):
  441. if keypath[0].name == "storage_offset":
  442. return go(
  443. graph.call_function(
  444. torch.ops.aten.sym_storage_offset.default,
  445. (node,),
  446. ),
  447. keypath[1:],
  448. )
  449. return go(
  450. graph.call_method(keypath[0].name, (node,)), keypath[1:]
  451. )
  452. elif isinstance(keypath[0], pytree.SequenceKey):
  453. return go(
  454. graph.call_function(
  455. operator.getitem, (node, keypath[0].idx)
  456. ),
  457. keypath[1:],
  458. )
  459. elif isinstance(keypath[0], ConvertIntKey):
  460. return go(
  461. graph.call_function(
  462. cast_symbool_to_symint_guardless, (node,)
  463. ),
  464. keypath[1:],
  465. )
  466. elif isinstance(keypath[0], DivideByKey):
  467. # TODO: need to assert divisibility
  468. return go(
  469. graph.call_function(
  470. operator.floordiv, (node, keypath[0].divisor)
  471. ),
  472. keypath[1:],
  473. )
  474. elif isinstance(keypath[0], InnerTensorKey):
  475. return go(
  476. graph.call_function(
  477. getattr, (node, keypath[0].inner_name)
  478. ),
  479. keypath[1:],
  480. )
  481. else:
  482. raise AssertionError(f"unrecognized keypath {keypath}")
  483. if s not in expr_to_proxy:
  484. expr_to_proxy[s] = fx.Proxy(go(node, keypath), tracer=tracer)
  485. log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
  486. for i0 in defs:
  487. ras = ras_by_symbol.pop(i0, [])
  488. # Before we perform any asserts, first apply range
  489. # refinement. This is important, because if we are going
  490. # to retrace the graph (and we typically are if we send
  491. # the graph to AOTAutograd), we need to make sure we apply
  492. # range refinement (ala _check_is_size) first, BEFORE we
  493. # run any of the asserts. Otherwise, we may decide to
  494. # perform substitutions based on the asserts which we then
  495. # can't back out, because value ranges can only be applied
  496. # to asserts.)
  497. #
  498. # A perhaps better long term plan is to avoid this order
  499. # dependence by making it possible to refine ranges on
  500. # arbitrary expressions, not just symbols. But it is not
  501. # so easy to make use of this information, see
  502. # https://twitter.com/ezyang/status/1745801370299482492
  503. # We actually made an attempt at this in
  504. # https://github.com/pytorch/pytorch/pull/119043
  505. # which didn't work.
  506. #
  507. # Another ideas for how to do this:
  508. # - Have bound_sympy be the source of truth of the ranges of any expression
  509. # - Cache intermediate results for every subexpression of bound_sympy
  510. # - This cache should be possible to edit to refine ranges
  511. #
  512. # One issue with this proposal is that if
  513. # we have a bound on 2x, we are not going to be able to
  514. # apply it for 4x. Similarly, we may have bounds for an
  515. # equivalent expression that we are not applying because
  516. # it's not a perfect match (e.g. x < y vs y > x)".
  517. #
  518. # The first issue we already have it and it's impossible
  519. # to solve in general, so any implementation on a best
  520. # effort basis should do.
  521. #
  522. # The second issue is a preexisting one. It can be mitigated
  523. # with a normalization algorithm. In general, it may also
  524. # be on a best effort basis, but since our grammar is not
  525. # terribly difficult, chances are we could even fully
  526. # normalize SymPy expressions... who knows.
  527. if i0 in constrained_unbacked_symbols:
  528. continue # constrain symbol just once
  529. vr = shape_env.var_to_range[i0]
  530. if vr.is_int and vr.upper == sys.maxsize - 1:
  531. # treat upper bound == sys.maxsize - 1 for int symbols as +oo
  532. # to avoid redundant runtime assert
  533. vr = ValueRanges(vr.lower, int_oo)
  534. if not shape_env._default_unspecified_value_range().issubset(vr):
  535. # The runtime range is constrained, so add a runtime
  536. # assert and also explicitly refine the range
  537. # (refinement should not be necessary once runtime
  538. # asserts cause refinement, but that's NYI)
  539. def convert(s):
  540. if s in (int_oo, -int_oo):
  541. return None
  542. try:
  543. return int(s)
  544. except TypeError:
  545. return None
  546. if (
  547. expr_to_proxy[i0].node.target
  548. is not cast_symbool_to_symint_guardless
  549. ):
  550. # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts
  551. # raises AOTAutograd errors on cast_symbool_to_symint_guardless
  552. if (min_val := convert(vr.lower)) is not None:
  553. ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
  554. graph.call_function(
  555. torch.ops.aten._assert_scalar.default,
  556. (
  557. ge,
  558. f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'",
  559. ),
  560. )
  561. added_asserts.add(i0 >= min_val)
  562. if (max_val := convert(vr.upper)) is not None:
  563. le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
  564. graph.call_function(
  565. torch.ops.aten._assert_scalar.default,
  566. (
  567. le,
  568. f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'",
  569. ),
  570. )
  571. added_asserts.add(i0 <= max_val)
  572. constrained_unbacked_symbols.add(i0)
  573. add_runtime_asserts(ras)
  574. # delete unused reified symbols
  575. for expr, proxy in expr_to_proxy.items():
  576. if (
  577. isinstance(expr, sympy.Symbol)
  578. and proxy.node.op != "placeholder" # keep placeholders intact
  579. and not proxy.node.users
  580. ):
  581. log.debug("deleting unused reified symbol for %s", expr)
  582. gm.graph.erase_node(proxy.node)