# validator.py
  1. # mypy: allow-untyped-defs
  2. import builtins
  3. import functools
  4. import logging
  5. import math
  6. import operator
  7. from collections.abc import Callable
  8. from dataclasses import dataclass
  9. from typing import Any, Optional, Union
  10. import sympy
  11. import torch
  12. import torch.fx
  13. import torch.fx.traceback as fx_traceback
  14. from torch._dynamo.exc import TorchDynamoException
  15. from torch._dynamo.utils import dynamo_timed
  16. from torch.fx.node import Argument, Target
  17. from torch.utils._sympy.interp import sympy_interp
  18. log = logging.getLogger(__name__)
  19. try:
  20. import z3 # type: ignore[import]
  21. # Translation Validation for Dynamo guards
  22. # ========================================
  23. #
  24. # Checks whether optimizations applied to the collected guards are
  25. # valid. In other words, whether the guard function we actually run
  26. # does not have false positives (unsound).
  27. #
  28. # In order to do so, we build the guards using 2 different information
  29. # attached to each 'SymNode':
  30. # 1. SymPy expressions
  31. # 2. FX nodes
  32. #
  33. # SymPy expressions have implicit optimizations baked within itself,
  34. # which may have a few bugs. On the other hand, we build the FX graph
  35. # manually, with no optimizations enabled. This gives us access to
  36. # the "ground truth".
  37. #
  38. # We then convert into Z3 expressions both the SymPy expressions
  39. # (see [Note: SympyToZ3]) that reach 'ShapeEnv.produce_guards' function
  40. # and the FX nodes (see [Note: PopulateValidator]) that go through
  41. # 'ShapeEnv.evaluate_expr' function. Finally, we run the validation.
  42. # (see [Note: TranslationValidator])
  43. # Better Z3 to string implementation (for a small fraction of Z3).
  44. #
  45. # Here are the things we clean before showing the Z3 expression:
  46. # - Rename a few ops (e.g. "Distinct" ==> "!=")
  47. #
  48. # - Ignore ToInt and ToReal operations:
  49. # usually they don't really matter
  50. #
  51. # - Transform (ToInt (/ ...)) into (idiv ...):
  52. # this is the pattern for floor division
  53. #
  54. # - Collect a chain of the same operations into one
  def z3str(e: z3.ExprRef) -> str:
      """Render a small subset of Z3 expressions as a readable s-expression.

      The expression is simplified with ``z3.simplify`` first. Integer and
      rational literals are printed via ``as_string()``, uninterpreted
      constants by their declaration name, ``ToInt``/``ToReal`` wrappers are
      dropped (a wrapped ``(/ ...)`` is shown as ``(idiv ...)``), chains of the
      same ``+`` or ``*`` operator are flattened, and ``Not`` applied to
      ``==``/``<=``/``>=`` is shown as ``!=``/``>``/``<``.

      Raises:
          AssertionError: if ``e`` is not a Z3 expression, or an operator has
              an unexpected number of arguments.
          ValueError: if the simplified expression is not a function
              application.
      """
      if not z3.is_expr(e):
          raise AssertionError(f"unsupported expression type: {e}")

      def render_children(node: z3.ExprRef) -> list[str]:
          return [z3str(node.arg(idx)) for idx in range(node.num_args())]

      # Simplification is rewrite-rule based, so it should be cheap.
      e = z3.simplify(e)

      # Only function applications are supported (Z3 "variables" are
      # function applications too).
      if not z3.is_app(e):
          raise ValueError(f"can't print Z3 expression: {e}")

      if z3.is_int_value(e) or z3.is_rational_value(e):
          return e.as_string()  # type: ignore[attr-defined]

      decl = e.decl()
      kind = decl.kind()
      op = str(decl)
      args = render_children(e)

      if kind == z3.Z3_OP_UNINTERPRETED:
          if e.num_args() != 0:
              raise AssertionError(f"Expected 0 args, got {e.num_args()}")
          return str(decl)

      if kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL):
          if e.num_args() != 1:
              raise AssertionError(f"Expected 1 arg, got {e.num_args()}")
          inner_str = z3str(e.arg(0))
          # A cast around a real division is the floor-division pattern;
          # any other cast is simply omitted from the output.
          if inner_str.startswith("(/"):
              return "(idiv" + inner_str[2:]
          return inner_str

      if kind == z3.Z3_OP_POWER:
          op = "pow"
      elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL):
          # ADD and MUL are associative, so nested applications of the same
          # operator can be shown as one flat argument list.
          def flatten(node):
              if not (z3.is_app(node) and node.decl().kind() == kind):
                  return [z3str(node)]
              pieces = []
              for idx in range(node.num_args()):
                  pieces.extend(flatten(node.arg(idx)))
              return pieces

          args = flatten(e)
      elif kind == z3.Z3_OP_NOT:
          # Undo rewrites done by z3.simplify:
          #   a != b  ==>  (Not (== a b))  ==>  (!= a b)
          #   a <  b  ==>  (Not (<= b a))  ==>  (> b a)
          #   a >  b  ==>  (Not (<= a b))  ==>  (> a b)
          if e.num_args() != 1:
              raise AssertionError(f"Expected 1 arg, got {e.num_args()}")
          inner = e.arg(0)
          if not z3.is_app(inner):
              raise AssertionError("Expected z3 app")
          negated_op = {
              z3.Z3_OP_EQ: "!=",
              z3.Z3_OP_LE: ">",
              z3.Z3_OP_GE: "<",
          }.get(inner.decl().kind())
          if negated_op is not None:
              op = negated_op
              args = render_children(inner)

      body = f"{op} {' '.join(args)}".rstrip()
      return f"({body})"
  122. # We need to convert to/from BitVec in order to use z3 bitwise ops.
  123. # We assume that integers are 64 bit.
  124. # If all args are boolean, then use the boolean bitwise op implementation instead, if provided.
  def _bitwise_op(bitwise_func, bool_func):
      """Build a ``_Z3Ops`` method that applies a bitwise operator in Z3.

      If ``bool_func`` is given and every argument is a ``z3.BoolRef``, the
      boolean implementation is used. Otherwise each argument is converted to
      a 64-bit bit-vector, ``bitwise_func`` is applied, and the result is
      converted back to an integer.
      """

      @functools.wraps(bitwise_func)
      def wrapper(self, *args):
          all_boolean = bool_func is not None and all(
              isinstance(operand, z3.BoolRef) for operand in args
          )
          if all_boolean:
              return bool_func(*args)
          as_bitvecs = [z3.Int2BV(operand, 64) for operand in args]
          return z3.BV2Int(bitwise_func(*as_bitvecs))

      return wrapper
  135. # Implementation of Python semantics as Z3 expressions.
  136. #
  137. # Z3 Real-Int theory has operators whose semantics differ from those of
  138. # Python. Therefore, in order to get it right, we need to implement
  139. # the (Python) semantics we are relying on in Z3.
  @dataclass
  class _Z3Ops:
      """Python arithmetic semantics expressed as Z3 expressions.

      Operations that are only defined under a side condition (for example a
      non-zero denominator) register that condition on ``validator`` through
      ``add_assertion``.
      """

      # Validator that receives the side-condition assertions.
      validator: "TranslationValidator"

      @staticmethod
      def to_real(x: z3.ArithRef) -> z3.ArithRef:
          """Return ``x`` as a real expression, casting only if needed."""
          if x.is_real():
              return x
          return z3.ToReal(x)

      @staticmethod
      def to_int(x: z3.ArithRef) -> z3.ArithRef:
          """Return ``x`` as an integer expression, casting only if needed."""
          if x.is_int():
              return x
          return z3.ToInt(x)

      def sym_sum(self, args: z3.ArithRef) -> z3.ArithRef:
          """Sum the given Z3 expressions."""
          return sum(args)  # pyrefly: ignore [no-matching-overload]

      def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
          """True division: asserts a non-zero denominator, divides as reals."""
          self.validator.add_assertion(denominator != 0)  # type: ignore[arg-type]
          real_numerator = _Z3Ops.to_real(numerator)
          real_denominator = _Z3Ops.to_real(denominator)
          return real_numerator / real_denominator

      def floor(self, number: z3.ArithRef) -> z3.ArithRef:
          """Floor of ``number``; Z3's ToInt rounds towards negative infinity."""
          return _Z3Ops.to_int(number)

      def floordiv(
          self, numerator: z3.ArithRef, denominator: z3.ArithRef
      ) -> z3.ArithRef:
          """Floor division; the result is real if either operand is real."""
          cast_result_to_real = numerator.is_real() or denominator.is_real()
          quotient = _Z3Ops.to_int(self.div(numerator, denominator))
          # 'quotient' is already an integer; only a cast back to real may
          # still be needed.
          if cast_result_to_real:
              return _Z3Ops.to_real(quotient)
          return quotient

      def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
          """Ceiling of ``number`` expressed through ``floor``."""
          has_fraction = self.floor(number) < number
          return z3.If(has_fraction, self.floor(number + 1), number)  # type: ignore[return-value]

      def trunc(self, number: z3.ArithRef) -> z3.ArithRef:
          """Truncation: floor for non-negative values, ceil otherwise."""
          non_negative = number >= 0
          return z3.If(non_negative, self.floor(number), self.ceil(number))  # type: ignore[return-value]

      def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
          """Return ``a`` when ``a > b``, else ``b``."""
          return z3.If(a > b, a, b)  # type: ignore[return-value]

      def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
          """Return ``a`` when ``a < b``, else ``b``."""
          return z3.If(a < b, a, b)  # type: ignore[return-value]

      def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
          """Modulo defined as ``p - floordiv(p, q) * q``."""
          quotient = self.floordiv(p, q)
          return p - quotient * q

      def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
          """Exponentiation; asserts ``base != 0`` or ``exp > 0``."""
          # Z3 can't handle complex numbers very well.
          self.validator.add_assertion(z3.Or(base != 0, exp > 0))  # type: ignore[arg-type]
          return base**exp

      def sqrt(self, number: z3.ArithRef) -> z3.ArithRef:
          """Square root over the reals; asserts a non-negative operand."""
          number = _Z3Ops.to_real(number)
          # A negative operand would make Z3 answer 'unknown'.
          self.validator.add_assertion(number >= 0)
          return number**0.5

      def abs(self, number: z3.ArithRef) -> z3.ArithRef:
          """Absolute value."""
          return z3.Abs(number)

      def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef:
          """Round half to even, as Python's builtin ``round`` does.

          'Round half up' (``floor(x + 0.5)``) is the default; values whose
          remainder modulo 2 is exactly 0.5 (..., -3.5, -1.5, 0.5, 2.5, ...)
          use 'round half down' (``ceil(x - 0.5)``) instead.
          """
          needs_half_down = self.mod(number, z3.IntVal(2)) == 0.5
          return z3.If(
              needs_half_down,
              self.ceil(number - 0.5),
              self.floor(number + 0.5),
          )

      # Bitwise operators go through 64-bit bit-vectors; and/or fall back to
      # the boolean connectives when every operand is boolean.
      bitwise_and = _bitwise_op(operator.and_, z3.And)
      bitwise_or = _bitwise_op(operator.or_, z3.Or)
      lshift = _bitwise_op(operator.lshift, None)
      rshift = _bitwise_op(operator.rshift, None)
  219. # Lifts a callable to be used in Z3.
  220. #
  221. # This function replaces the given 'op' by a function that:
  222. #
  223. # 1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3)
  224. #
  225. # 2. Calls an operation that corresponds to 'op', but works with Z3
  226. # inhabitants (left as is if it works as is)
  def z3op(op: Callable, validator: "TranslationValidator") -> Callable:
      """Return a callable that performs ``op`` on Z3 expressions.

      The returned function first lifts its arguments into Z3 values (Python
      and SymPy numbers become Z3 literals; Z3 expressions pass through), then
      calls the Z3 counterpart of ``op``. Operators without an entry in the
      replacement table are called as-is on the lifted arguments.

      Args:
          op: the FX node target to translate.
          validator: receives side-condition assertions from operations such
              as division.

      Raises (from the returned callable):
          ValueError: if an argument has a type that cannot be lifted.
      """
      # Some FX nodes carry literal integers where booleans are expected.
      # For these operators, ints are lifted as booleans.
      boolean_ops = {operator.not_}
      as_bool = op in boolean_ops

      def lift(func):
          def wrap(a) -> z3.ExprRef:
              if isinstance(a, (z3.ArithRef, z3.BoolRef)):
                  return a
              # 'bool' is checked before 'int' because bool is an int subclass.
              if isinstance(a, bool) or (as_bool and isinstance(a, int)):
                  return z3.BoolVal(bool(a))
              if isinstance(a, (int, sympy.Integer)):
                  return z3.IntVal(int(a))
              if isinstance(a, (float, sympy.Float)):
                  return z3.RealVal(float(a))
              raise ValueError(f"can't lift type: {type(a)}")

          @functools.wraps(func)
          def wrapper(*args):
              # A single list/tuple argument is lifted element-wise and passed
              # on as one tuple; otherwise each positional argument is lifted.
              if len(args) == 1 and isinstance(args[0], (list, tuple)):
                  wrapped_args = (tuple(wrap(a) for a in args[0]),)
              else:
                  wrapped_args = tuple(wrap(a) for a in args)
              return func(*wrapped_args)

          return wrapper

      ops = _Z3Ops(validator)
      replacement_map = {
          # Operator module.
          operator.not_: lift(z3.Not),
          operator.and_: lift(ops.bitwise_and),
          operator.or_: lift(ops.bitwise_or),
          operator.lshift: lift(ops.lshift),
          operator.rshift: lift(ops.rshift),
          operator.floordiv: lift(ops.floordiv),
          operator.truediv: lift(ops.div),
          operator.mod: lift(ops.mod),
          operator.abs: lift(ops.abs),
          builtins.round: lift(ops.round_to_int),
          # Math module.
          math.ceil: lift(ops.ceil),
          math.floor: lift(ops.floor),
          math.trunc: lift(ops.trunc),
          # Torch module.
          torch.sym_float: lift(ops.to_real),
          torch.sym_max: lift(ops.max),
          torch.sym_min: lift(ops.min),
          torch.sym_sum: lift(ops.sym_sum),
          # The condition is a Z3 boolean after lifting. A Python
          # 'if/else' would cast it to a concrete bool, which is wrong or
          # raises for symbolic conditions, so build a Z3 if-then-else.
          torch.sym_ite: lift(z3.If),
          torch._sym_sqrt: lift(ops.sqrt),  # type: ignore[attr-defined]
          # Not lifted: this function is only a marker for adding the
          # expression as validator input.
          torch._assert: torch._assert,
      }
      return replacement_map[op] if op in replacement_map else lift(op)
  287. # Processes an FX graph, populating the given validator.
  288. #
  289. # [Note: PopulateValidator]
  290. # This class walks through each node in the FX graph, translating
  291. # them into the Z3 world.
  292. #
  293. # Then, whenever it finds an 'torch._assert' call_function operation,
  294. # it adds the Z3 expression corresponding to the argument as validator
  295. # input.
  class PopulateValidator(torch.fx.Interpreter):
      """FX interpreter that translates graph nodes into Z3 expressions.

      Placeholders resolve to the validator's Z3 variables, function calls are
      executed through their Z3 counterpart (see ``z3op``), and the argument
      of every ``torch._assert`` call is registered as a source expression.
      """

      def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"):
          # Kept so node handlers can reach the translation validator.
          self.validator = validator

          # The Interpreter base class needs a GraphModule to run.
          graph_module = torch.fx.GraphModule(root={}, graph=graph)
          super().__init__(graph_module, garbage_collect_values=True)

      def placeholder(
          self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
      ) -> Any:
          """Return the Z3 variable of the symbol stored in the node meta."""
          current_meta = fx_traceback.get_current_meta()
          return self.validator.z3var(current_meta["symbol"])

      def call_function(
          self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
      ) -> Any:
          """Run ``target`` in Z3, or record an assertion's argument."""
          if target is torch._assert:
              # The single argument of the assertion becomes a validator
              # source expression.
              if len(args) != 1:
                  raise AssertionError(
                      f"expected 1 argument on assertion. Got: {len(args)} "
                  )
              self.validator.add_source_expr(args[0])  # type: ignore[arg-type]
              return None

          # Any other target: lift it and run it on the Z3 arguments.
          lifted_target = z3op(target, self.validator)
          return super().call_function(lifted_target, args, kwargs)  # type: ignore[arg-type]
  321. # Translates SymPy expressions into Z3 expressions.
  322. #
  323. # [Note: SympyToZ3]
  324. # At the time of the translation, all free variables present in the
  325. # SymPy expression being translated must be already mapped to a Z3
  326. # integer variable.
  class SympyToZ3:
      """Handler object for ``sympy_interp`` that builds Z3 expressions.

      Every free symbol of a translated SymPy expression must already be
      present in ``validator.symbols``. Handlers that are not defined as
      methods are resolved dynamically by ``__getattr__``.
      """

      # Handler names that map directly onto functions of the 'operator' module.
      OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"}

      def __init__(
          self,
          validator: "TranslationValidator",
      ) -> None:
          self._validator = validator
          self._ops = _Z3Ops(self._validator)

      def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef:
          """Return the Z3 literal for ``value`` according to ``dtype``.

          Raises:
              ValueError: for dtypes other than int64, double and bool.
          """
          # TODO: Probably OK to relax this and allow lower precision
          if dtype is torch.int64:
              return z3.IntVal(int(value))
          if dtype is torch.double:
              return z3.RealVal(float(value))
          if dtype is torch.bool:
              return z3.BoolVal(bool(value))
          raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}")

      def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
          """Cast ``x`` to real for float64; other dtypes are not implemented."""
          if dtype == torch.float64:
              return z3.ToReal(x)
          raise NotImplementedError(f"to_dtype {dtype} NYI")

      def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
          """Truncate ``x`` towards zero and return an integer expression."""
          # z3.ToInt rounds towards negative infinity, which is floor and not
          # truncation for negative non-integers. Use the truncation from
          # _Z3Ops and make sure the result has integer sort.
          return _Z3Ops.to_int(self._ops.trunc(x))

      def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
          return self._ops.round_to_int(x)

      def int_truediv(
          self, numerator: z3.ArithRef, denominator: z3.ArithRef
      ) -> z3.ArithRef:
          return self._ops.div(numerator, denominator)

      def truediv(
          self, numerator: z3.ArithRef, denominator: z3.ArithRef
      ) -> z3.ArithRef:
          return self._ops.div(numerator, denominator)

      def floordiv(
          self, numerator: z3.ArithRef, denominator: z3.ArithRef
      ) -> z3.ArithRef:
          return self._ops.floordiv(numerator, denominator)

      def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
          # This handler is translated as a floor division.
          return self._ops.floordiv(numerator, denominator)

      def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
          return self._ops.pow(base, exp)

      def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
          return self._ops.pow(base, exp)

      def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
          return self._ops.mod(p, q)

      def python_mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
          return self._ops.mod(p, q)

      def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
          return self._ops.ceil(x)

      def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
          return self._ops.floor(x)

      def __getattr__(self, name: str) -> Any:
          """Resolve handlers that are not defined as regular methods.

          Raises:
              AttributeError: if ``name`` is not a known handler.
          """
          # __getattr__ only runs when normal lookup fails. If a private
          # attribute such as '_ops' is missing (e.g. on a half-initialised
          # instance), touching 'self._ops' below would recurse forever.
          if name.startswith("_"):
              raise AttributeError(f"unhandled operator: {name}")

          REPLACEMENT = {
              "and_": z3.And,
              "or_": z3.Or,
              "not_": z3.Not,
              "bitwise_and": self._ops.bitwise_and,
              "bitwise_or": self._ops.bitwise_or,
              "lshift": self._ops.lshift,
              "rshift": self._ops.rshift,
              "floor": self._ops.floor,
              "ceil": self._ops.ceil,
              "minimum": self._ops.min,
              "maximum": self._ops.max,
          }
          if name in REPLACEMENT:
              return REPLACEMENT[name]
          if name in self.OPERATOR_HANDLES:
              return getattr(operator, name)
          raise AttributeError(f"unhandled operator: {name}")

      def run(self, expr: sympy.Basic) -> z3.ExprRef:
          """Translate ``expr`` using the validator's symbol-to-variable map."""
          return sympy_interp(self, self._validator.symbols, expr)  # type: ignore[arg-type]
  399. # Dynamo guards translation validator.
  400. #
  401. # [Note: TranslationValidator]
  402. # Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound.
  403. # That is: whether those (target) guards only yield TRUE whenever the original,
  404. # unoptimized, (source) guards yield TRUE.
  405. #
  406. # More concretely, given 'source' and 'target' guard expressions, we wish to
  407. # check whether the following expression holds:
  408. #
  409. # Not(And(source)) AND And(target)
  410. #
  411. # i.e. whether there is an assignment of the free variables where the opposite
  412. # happens: target is TRUE, but source is FALSE.
  class TranslationValidator:
      """Checks that the target guards never hold when the source guards fail.

      Source expressions are the unsimplified guards (as Z3), target
      expressions are the guards actually checked at runtime, and assertions
      are side conditions that apply to both. ``validate`` asks Z3 whether
      ``Not(And(source)) AND And(target)`` is satisfiable under the assertions.
      """

      def __init__(self) -> None:
          log.debug("new instance")

          # Mapping of SymPy symbols to Z3 variables.
          self.symbols: dict[sympy.Symbol, z3.ExprRef] = {}

          # Source Z3 expressions: the generated guards without any
          # simplification or transformation.
          self._source_exprs: set[z3.BoolRef] = set()

          # Target Z3 expressions: the guards checked at runtime, possibly
          # simplified or transformed versions of the source guards.
          self._target_exprs: set[z3.BoolRef] = set()

          # Assertions over both the source and target expressions.
          self._assertions: set[z3.BoolRef] = set()

      def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef:
          """Return the Z3 variable for ``symbol``.

          Raises:
              AssertionError: if no variable was created for ``symbol``.
          """
          if symbol not in self.symbols:
              raise AssertionError(f"Z3 variable not found for: {symbol}")
          return self.symbols[symbol]

      def add_var(self, symbol: sympy.Symbol, type: type) -> z3.ExprRef:
          """Create (or fetch) the Z3 variable of the given type for ``symbol``.

          Raises:
              RuntimeError: if ``type`` is not ``int``, ``float`` or ``bool``.
          """
          if symbol in self.symbols:
              return self.symbols[symbol]

          log.debug("new variable: %s (%s)", symbol.name, type.__name__)

          if type is int:
              var = z3.Int(symbol.name)

              # A SymPy positivity assumption has to be conveyed to Z3 too.
              if symbol.is_positive:  # type: ignore[attr-defined]
                  self._target_exprs.add(var > 0)
          elif type is float:
              var = z3.Real(symbol.name)
          elif type is bool:
              var = z3.Bool(symbol.name)
          else:
              raise RuntimeError(f"unsupported type for Z3 variable: {type}")

          self.symbols[symbol] = var
          return var

      def _check_freesymbols(self, e: sympy.Basic) -> None:
          """Raise AssertionError unless every free symbol has a Z3 variable."""
          for s in e.free_symbols:
              if not isinstance(s, sympy.Symbol):
                  raise AssertionError(f"Expected sympy.Symbol, got {type(s)}")
              # Called only for its check that a variable exists for 's'.
              self.z3var(s)

      def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef:
          """Translate ``e`` to Z3, requiring a boolean result."""
          z3expr = SympyToZ3(self).run(e)
          if not isinstance(z3expr, z3.BoolRef):
              raise AssertionError(f"expected boolean expression. Got: {z3expr}")
          return z3expr

      def add_source_expr(self, e: z3.BoolRef) -> None:
          """Record a source guard, logging it the first time it is seen."""
          if e not in self._source_exprs:
              log.debug("add source guard: %s", z3str(e))
          self._source_exprs.add(e)

      def add_target_expr(self, e: "sympy.logic.boolalg.Boolean") -> None:
          """Translate and record a target guard, logging it the first time."""
          self._check_freesymbols(e)
          z3expr = self.to_z3_boolean_expr(e)
          # The set holds Z3 expressions, so membership must be tested with
          # the translated expression, not with the SymPy one.
          if z3expr not in self._target_exprs:
              log.debug("add target guard: %s", z3str(z3expr))
          self._target_exprs.add(z3expr)

      def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None:
          """Record a side condition given as a Z3 or SymPy expression.

          Raises:
              AssertionError: if the (translated) expression is not boolean.
          """
          if isinstance(e, sympy.Basic):
              self._check_freesymbols(e)
              ref = self.to_z3_boolean_expr(e)
          else:
              ref = e
          if not isinstance(ref, z3.BoolRef):
              raise AssertionError(f"Expected z3.BoolRef, got {type(ref)}")
          if ref not in self._assertions:
              log.debug("add assertion: %s", z3str(ref))
          self._assertions.add(ref)

      def validate(self) -> None:
          """Run the validation under the ``dynamo_timed`` context."""
          with dynamo_timed("TranslationValidator.validate"):
              return self._validate()

      def _validate(self) -> None:
          """Ask Z3 for an assignment where target holds but source fails.

          Raises:
              ValidationException: if such an assignment exists.
              AssertionError: if the solver returns an unexpected result.
          """
          if len(self._source_exprs) == 0 or len(self._target_exprs) == 0:
              # Without source or target expressions there is nothing to
              # prove.
              return None

          # "QF_NRA" is quantifier-free non-linear real arithmetic. Guards
          # mix integers and reals, have no quantifiers and may be
          # non-linear. "QF_NIRA" (mixed integer-real) also exists, but
          # "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'.
          solver = z3.SolverFor("QF_NRA")
          # Set a timeout for finding a solution.
          solver.set(timeout=translation_validation_timeout())

          for assertion in self._assertions:
              solver.add(assertion)

          # "Is there any case where it's TRUE for the target expressions,
          # but FALSE for the source expressions?"
          solver.add(z3.Not(z3.And(*self._source_exprs)))
          solver.add(*self._target_exprs)

          log.debug("translation validation: start")
          r = solver.check()
          if r == z3.sat:
              # Target expressions are unsound: report the model and the
              # source expressions it falsifies.
              model = solver.model()
              # Evaluate with model completion and test the result with
              # z3.is_false: casting a Z3 expression to a Python bool raises
              # when the evaluation does not reduce to a constant.
              failed_source_exprs = [
                  inp
                  for inp in self._source_exprs
                  if z3.is_false(model.evaluate(inp, model_completion=True))
              ]
              raise ValidationException(
                  model,
                  self._assertions,
                  self._target_exprs,
                  failed_source_exprs=failed_source_exprs,
              )

          if r == z3.unknown:
              # Neither failure nor success. Canceling the validation
              # (keyboard interrupt) also ends up here.
              log.warning(
                  "translation validation: could not validate: got z3.unknown"
              )
              return None

          # Target expressions are sound.
          if r != z3.unsat:
              raise AssertionError(f"Expected z3.unsat, got {r}")
          log.debug("translation validation: success")
  542. except ImportError:
  543. _HAS_Z3 = False
  544. __all__ = [
  545. "translation_validation_enabled",
  546. "translation_validation_timeout",
  547. "ValidationException",
  548. "BisectValidationException",
  549. ]
  550. else:
  551. _HAS_Z3 = True
  552. __all__ = [
  553. "z3str",
  554. "z3op",
  555. "PopulateValidator",
  556. "SympyToZ3",
  557. "TranslationValidator",
  558. "translation_validation_enabled",
  559. "translation_validation_timeout",
  560. "ValidationException",
  561. "BisectValidationException",
  562. ]
  563. from torch.fx.experimental import _config as config
  def translation_validation_enabled() -> bool:
      """Tell whether translation validation is switched on and usable.

      The Z3 availability check runs on every call, in case the option is set
      while Z3 is not installed.
      """
      _assert_z3_installed_if_tv_set()
      if not _HAS_Z3:
          return False
      return config.translation_validation
  def translation_validation_timeout() -> int:
      """Return the configured ``translation_validation_timeout`` value."""
      timeout = config.translation_validation_timeout
      return timeout
  def _assert_z3_installed_if_tv_set():
      """Raise AssertionError if validation is enabled but Z3 is missing."""
      z3_missing_but_required = not _HAS_Z3 and config.translation_validation
      if z3_missing_but_required:
          raise AssertionError(
              "translation validation requires Z3 package. Please, either install "
              "z3-solver or disable translation validation."
          )
  class ValidationException(TorchDynamoException):
      """Raised when Z3 finds a model where target guards hold but source fails.

      ``msg`` holds a one-line summary and ``details`` a listing of the model,
      the assertions, the target expressions and the failed source
      expressions.
      """

      def __init__(self, model, assertions, target_exprs, failed_source_exprs):
          if not _HAS_Z3:
              raise AssertionError("Z3 is required")

          # NOTE(review): whitespace inside these messages is reproduced as it
          # appears in this copy of the file; confirm against the canonical
          # source.
          def as_bullets(items) -> str:
              return "\n".join(f" ==> {item}" for item in items)

          sections = (
              ("Model:", sorted(f"{sym}: {model[sym]}" for sym in model)),
              ("Assertions:", sorted(map(z3str, assertions))),
              ("Target Expressions:", sorted(map(z3str, target_exprs))),
              ("Failed Source Expressions:", sorted(map(z3str, failed_source_exprs))),
          )

          self.msg = "translation validation failed."
          self.details = "\n".join(
              f"{title}\n{as_bullets(rows)}" for title, rows in sections
          )

      def __str__(self):
          return f"{self.msg}\n\n{self.details}"
  class BisectValidationException(TorchDynamoException):
      """Validation failure pinned to one traced node by bisection.

      ``msg`` names the failed action and expression; ``details`` shows the
      node being run followed by the details of the underlying validation
      exception.
      """

      def __init__(self, validation_exc, expr, failed_action, traced_node):
          self.msg = f"translation validation failed when {failed_action}: {expr}"
          # NOTE(review): whitespace inside this message is reproduced as it
          # appears in this copy of the file; confirm against the canonical
          # source.
          detail_lines = (
              "Failure occurred while running node:",
              f"{traced_node.format_node()}",
              f"{validation_exc.details}",
          )
          self.details = "\n".join(detail_lines)

      def __str__(self):
          return f"{self.msg}\n\n{self.details}"
  610. # Checks when this module is loaded.
  611. _assert_z3_installed_if_tv_set()
  612. # Translation validation bisection.
  613. #
  614. # Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise
  615. # the earliest ValidationException.
  616. #
  617. # As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors
  618. # might be silently happening. This function tries to nail down exactly at which
  619. # point things went wrong from a validation perspective.
  def bisect(shape_env):
      """Find the earliest recorded assertion node whose guards fail validation.

      ``produce_guards`` is first run on ``shape_env`` itself. If it does not
      raise a ``ValidationException`` the function returns. If bisection is
      disabled, or there is no ``torch._assert`` node to bisect over, that
      exception is re-raised. Otherwise the recorded events are replayed up
      to successive assertion nodes in a binary search.

      Raises:
          ValidationException: validation fails but bisection is not possible
              or is turned off.
          BisectValidationException: validation fails; carries the node and
              expression found by the bisection.
          AssertionError: if recorded metadata is missing or malformed.
      """
      from torch.fx.experimental.recording import (
          FakeTensorMeta,
          replay_shape_env_events,
          ShapeEnvEvent,
      )
      from torch.fx.experimental.symbolic_shapes import (
          CURRENT_NODE_KEY,
          ShapeEnv,
          SHAPEENV_EVENT_KEY,
      )

      events = shape_env.events

      # Retrieves the ShapeEnvEvent associated with node.
      def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent:
          if SHAPEENV_EVENT_KEY not in node.meta:
              raise AssertionError("SHAPEENV_EVENT_KEY not in node.meta")
          return events[node.meta[SHAPEENV_EVENT_KEY]]

      # Creates a new instance of fake, updating every symbolic value's
      # ShapeEnv reference to the one given as argument.
      #
      # This avoids simplifying a symbolic expression with a ShapeEnv "from
      # the future", which may have a different set of replacements.
      def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any:
          if isinstance(fake, int):
              return fake
          if isinstance(fake, torch.SymInt):
              return torch.SymInt(fake.node.with_shape_env(shape_env))
          if isinstance(fake, torch.SymFloat):
              return torch.SymFloat(fake.node.with_shape_env(shape_env))
          if not isinstance(fake, FakeTensorMeta):
              raise AssertionError(f"Expected FakeTensorMeta, got {type(fake)}")
          return FakeTensorMeta(
              tuple(new_with_shape_env(shape_env, s) for s in fake.size()),
              tuple(new_with_shape_env(shape_env, s) for s in fake.stride()),
              new_with_shape_env(shape_env, fake.storage_offset()),
              fake.is_nested,
          )

      # Checks whether the given shape_env fails when produce_guards is called.
      def check_shapeenv_fails(
          shape_env: ShapeEnv, tracked_fakes: Optional[list[Any]]
      ) -> Optional[ValidationException]:
          if tracked_fakes is None:
              raise AssertionError("tracked_fakes is None")
          try:
              # This produce_guards call is a best-effort replication, since
              # the EqualityConstraint list is not populated. Reason: we would
              # also have to save OutputGraph.tracked_fakes_id_to_source.
              shape_env.produce_guards(
                  [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes],
                  [a.source for a in tracked_fakes],
                  input_contexts=[a.symbolic_context for a in tracked_fakes],
              )
              return None
          except ValidationException as e:
              return e

      # Checks whether the ShapeEnv reconstructed by replaying the events
      # until node is created fails when produce_guards is called.
      def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]:
          number = node.meta[SHAPEENV_EVENT_KEY]
          # Reconstruct shape_env up to and including event 'number'.
          shape_env = replay_shape_env_events(events[: number + 1])
          shape_env.graph.lint()
          return check_shapeenv_fails(shape_env, events[number].tracked_fakes)

      last_exception = check_shapeenv_fails(
          shape_env, shape_env._snapshot_tracked_fakes()
      )

      if not last_exception:
          # produce_guards does not fail: nothing to bisect.
          log.info("translation validation succeeded: no errors found.")
          return

      if not shape_env.should_record_events or config.translation_validation_no_bisect:
          # Bisection is off: raise the ValidationException we got.
          raise last_exception

      # Bisection happens on the assertion nodes of the recorded FX graph for
      # dynamic shapes.
      assert_nodes = [
          node for node in shape_env.graph.nodes if node.target is torch._assert
      ]

      if not assert_nodes:
          # Nothing to bisect over. Indexing the empty list below would raise
          # IndexError and hide the real failure, so report it directly.
          raise last_exception

      # Cache of the exception (if any) raised at each bisection point.
      exception: dict[int, Optional[ValidationException]] = {}

      # Binary search invariants:
      #   - for all i < left, assert_node[i] doesn't fail
      #   - for all i >= right, assert_node[i] fails
      #   - `right in exception` always holds
      #   - `left <= right` always holds
      left, right = 0, len(assert_nodes) - 1
      exception[right] = check_node_fails(assert_nodes[right])

      while left < right:
          mid = (left + right) // 2

          node = assert_nodes[mid]
          log.debug("bisecting at %s: %s", mid, get_node_event(node))

          # Check whether the replayed shape_env raises a ValidationException.
          exception[mid] = check_node_fails(node)

          if exception[mid]:
              right = mid
          else:
              left = mid + 1

      if not (left in exception and isinstance(exception[left], ValidationException)):
          raise AssertionError("Expected ValidationException at bisect result")

      node = assert_nodes[left]
      event = get_node_event(node)

      if event.is_evaluate_expr():
          failed_action = "evaluating"
      else:
          if not event.is_defer_runtime_assert():
              raise AssertionError(f"unexpected event type: {event}")
          failed_action = "adding runtime assert"

      args = event.args
      if args is None:
          raise AssertionError("event.args is None")
      if len(args) < 2:
          raise AssertionError(
              f"bisecting expects {event.name} to have at least 2 positional arguments. "
              f"Got: {len(args)}"
          )
      if not isinstance(args[1], sympy.Basic):
          raise AssertionError(
              f"bisecting expects {event.name} to have a SymPy expression as its second "
              f"argument. Got: {type(args[1])}"
          )

      raise BisectValidationException(
          exception[left],
          expr=args[1],
          failed_action=failed_action,
          traced_node=node.meta[CURRENT_NODE_KEY],
      )