solve.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import logging
  2. import sympy
  3. from torch.utils._sympy.functions import FloorDiv
  4. log = logging.getLogger(__name__)
  5. _MIRROR_REL_OP: dict[type[sympy.Basic], type[sympy.Rel]] = {
  6. sympy.Eq: sympy.Eq,
  7. sympy.Ne: sympy.Ne,
  8. sympy.Ge: sympy.Le,
  9. sympy.Gt: sympy.Lt,
  10. sympy.Le: sympy.Ge,
  11. sympy.Lt: sympy.Gt,
  12. }
  13. INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
  14. def mirror_rel_op(type: type) -> type[sympy.Rel] | None:
  15. return _MIRROR_REL_OP.get(type)
  16. # Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side.
  17. #
  18. # Returns a tuple of:
  19. # 1. The simplified expression
  20. # 2. The expression on the right-hand side
  21. #
  22. # Returns 'None' if it can't reach a state where the only thing in the left
  23. # hand side is 'thing'.
  24. #
  25. # 'trials': number of times 'try_solve' will try to isolate 'thing' to the
  26. # left-hand side.
  27. #
  28. # 'floordiv_inequality': flag to enable conversion of 'FloorDiv' into
  29. # inequalities.
  30. def try_solve(
  31. expr: sympy.Basic,
  32. thing: sympy.Basic,
  33. trials: int = 5,
  34. floordiv_inequality: bool = True,
  35. ) -> tuple[sympy.Rel, sympy.Expr] | None:
  36. mirror = mirror_rel_op(type(expr))
  37. # Ignore unsupported expressions:
  38. # - Those that are not relational operations
  39. # - Those that don't have a mirror (just avoiding unexpected classes)
  40. if not isinstance(expr, sympy.Rel) or mirror is None:
  41. log.debug("expression with unsupported type: %s", type(expr))
  42. return None
  43. lhs_has_thing = expr.lhs.has(thing)
  44. rhs_has_thing = expr.rhs.has(thing)
  45. # Give up when 'thing' appears on both sides of the relational expression.
  46. # That is because, as is, we assume the thing we are trying to isolate is
  47. # only on the right-hand side.
  48. if lhs_has_thing and rhs_has_thing:
  49. log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
  50. return None
  51. # Try considering both LHS and RHS by mirroring the original expression:
  52. # a < b ==> b > a
  53. expressions = []
  54. # Add each version of 'expr' if 'thing' is in its left-hand side.
  55. if lhs_has_thing:
  56. expressions.append(expr)
  57. if rhs_has_thing:
  58. expressions.append(mirror(expr.rhs, expr.lhs))
  59. for e in expressions:
  60. if e is None:
  61. continue
  62. if not isinstance(e, sympy.Rel):
  63. raise AssertionError("expected sympy.Rel")
  64. for _ in range(trials):
  65. trial = _try_isolate_lhs(e, thing, floordiv_inequality=floordiv_inequality)
  66. # Stop if there was no change in this trial.
  67. if trial == e:
  68. break
  69. e = trial # type: ignore[assignment]
  70. # Return if we were able to isolate 'thing' on the left-hand side.
  71. if isinstance(e, sympy.Rel) and e.lhs == thing:
  72. log.debug("solved: %s ---> %s", expr, e)
  73. return e, e.rhs
  74. return None
  75. def _try_isolate_lhs(
  76. e: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
  77. ) -> sympy.Basic:
  78. op = type(e)
  79. if isinstance(e, sympy.Rel):
  80. # Move any constants in the left-hand side to the right-hand side.
  81. lhs_not_thing = (
  82. sum(a for a in e.lhs.args if not a.has(thing))
  83. if isinstance(e.lhs, sympy.Add)
  84. else 0
  85. )
  86. e = op(e.lhs - lhs_not_thing, e.rhs - lhs_not_thing) # type: ignore[attr-defined]
  87. # Divide both sides by the factors that don't contain thing.
  88. if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
  89. lhs, rhs = e.args
  90. other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])
  91. # If we can't tell whether 'other' is negative or positive, we do nothing.
  92. # That is because we don't know whether we have mirror the operation or not.
  93. # We also divide only when we know 'rhs' is not zero.
  94. if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None) and not (
  95. not isinstance(e, INEQUALITY_TYPES) and rhs.is_zero
  96. ):
  97. # Divide both sides by 'other'.
  98. lhs = lhs / other
  99. rhs = rhs / other
  100. # If 'e' is an inequality and 'other' is negative, we have to
  101. # mirror the expression.
  102. if isinstance(e, INEQUALITY_TYPES) and other.is_negative:
  103. op = mirror_rel_op(op) # type: ignore[assignment]
  104. if op is None:
  105. raise AssertionError("expected op to be not None")
  106. e = op(lhs, rhs)
  107. ################################################################################
  108. # left-hand side is FloorDiv
  109. ################################################################################
  110. #
  111. # Given the expression: a // b op c
  112. # where 'op' is a relational operation, these rules only work if:
  113. # - b > 0
  114. # - c is an integer
  115. if (
  116. floordiv_inequality
  117. and isinstance(e, sympy.Rel)
  118. and isinstance(e.lhs, FloorDiv)
  119. and e.lhs.divisor.is_positive
  120. and e.rhs.is_integer
  121. ):
  122. # a // b == expr
  123. # => a >= (b * expr) and a < (b * (expr + 1))
  124. if isinstance(e, sympy.Eq):
  125. numerator, denominator = e.lhs.args
  126. return sympy.And(
  127. sympy.Ge(numerator, (e.rhs * denominator)),
  128. sympy.Lt(numerator, ((e.rhs + 1) * denominator)),
  129. )
  130. # a // b != expr
  131. # => a < (b * expr) or a >= (b * (expr + 1))
  132. if isinstance(e, sympy.Ne):
  133. numerator, denominator = e.lhs.args
  134. return sympy.Or(
  135. sympy.Lt(numerator, (e.rhs * denominator)),
  136. sympy.Ge(numerator, ((e.rhs + 1) * denominator)),
  137. )
  138. # The transformations below only work if b is positive.
  139. # Note: we only have this information for constants.
  140. # a // b > expr => a >= b * (expr + 1)
  141. # a // b >= expr => a >= b * expr
  142. if isinstance(e, (sympy.Gt, sympy.Ge)):
  143. quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1)
  144. return sympy.Ge(e.lhs.args[0], (quotient * e.lhs.args[1]))
  145. # a // b < expr => a < b * expr
  146. # a // b <= expr => a < b * (expr + 1)
  147. if isinstance(e, (sympy.Lt, sympy.Le)):
  148. quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1)
  149. return sympy.Lt(e.lhs.args[0], (quotient * e.lhs.args[1]))
  150. return e