mod.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. from .add import Add
  2. from .exprtools import gcd_terms
  3. from .function import DefinedFunction
  4. from .kind import NumberKind
  5. from .logic import fuzzy_and, fuzzy_not
  6. from .mul import Mul
  7. from .numbers import equal_valued
  8. from .relational import is_le, is_lt, is_ge, is_gt
  9. from .singleton import S
  10. class Mod(DefinedFunction):
  11. """Represents a modulo operation on symbolic expressions.
  12. Parameters
  13. ==========
  14. p : Expr
  15. Dividend.
  16. q : Expr
  17. Divisor.
  18. Notes
  19. =====
  20. The convention used is the same as Python's: the remainder always has the
  21. same sign as the divisor.
  22. Many objects can be evaluated modulo ``n`` much faster than they can be
  23. evaluated directly (or at all). For this, ``evaluate=False`` is
  24. necessary to prevent eager evaluation:
  25. >>> from sympy import binomial, factorial, Mod, Pow
  26. >>> Mod(Pow(2, 10**16, evaluate=False), 97)
  27. 61
  28. >>> Mod(factorial(10**9, evaluate=False), 10**9 + 9)
  29. 712524808
  30. >>> Mod(binomial(10**18, 10**12, evaluate=False), (10**5 + 3)**2)
  31. 3744312326
  32. Examples
  33. ========
  34. >>> from sympy.abc import x, y
  35. >>> x**2 % y
  36. Mod(x**2, y)
  37. >>> _.subs({x: 5, y: 6})
  38. 1
  39. """
  40. kind = NumberKind
  41. @classmethod
  42. def eval(cls, p, q):
  43. def number_eval(p, q):
  44. """Try to return p % q if both are numbers or +/-p is known
  45. to be less than or equal q.
  46. """
  47. if q.is_zero:
  48. raise ZeroDivisionError("Modulo by zero")
  49. if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False:
  50. return S.NaN
  51. if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1):
  52. return S.Zero
  53. if q.is_Number:
  54. if p.is_Number:
  55. return p%q
  56. if q == 2:
  57. if p.is_even:
  58. return S.Zero
  59. elif p.is_odd:
  60. return S.One
  61. if hasattr(p, '_eval_Mod'):
  62. rv = getattr(p, '_eval_Mod')(q)
  63. if rv is not None:
  64. return rv
  65. # by ratio
  66. r = p/q
  67. if r.is_integer:
  68. return S.Zero
  69. try:
  70. d = int(r)
  71. except TypeError:
  72. pass
  73. else:
  74. if isinstance(d, int):
  75. rv = p - d*q
  76. if (rv*q < 0) == True:
  77. rv += q
  78. return rv
  79. # by difference
  80. # -2|q| < p < 2|q|
  81. if q.is_positive:
  82. comp1, comp2 = is_le, is_lt
  83. elif q.is_negative:
  84. comp1, comp2 = is_ge, is_gt
  85. else:
  86. return
  87. ls = -2*q
  88. r = p - q
  89. for _ in range(4):
  90. if not comp1(ls, p):
  91. return
  92. if comp2(r, ls):
  93. return p - ls
  94. ls += q
  95. rv = number_eval(p, q)
  96. if rv is not None:
  97. return rv
  98. # denest
  99. if isinstance(p, cls):
  100. qinner = p.args[1]
  101. if qinner % q == 0:
  102. return cls(p.args[0], q)
  103. elif (qinner*(q - qinner)).is_nonnegative:
  104. # |qinner| < |q| and have same sign
  105. return p
  106. elif isinstance(-p, cls):
  107. qinner = (-p).args[1]
  108. if qinner % q == 0:
  109. return cls(-(-p).args[0], q)
  110. elif (qinner*(q + qinner)).is_nonpositive:
  111. # |qinner| < |q| and have different sign
  112. return p
  113. elif isinstance(p, Add):
  114. # separating into modulus and non modulus
  115. both_l = non_mod_l, mod_l = [], []
  116. for arg in p.args:
  117. both_l[isinstance(arg, cls)].append(arg)
  118. # if q same for all
  119. if mod_l and all(inner.args[1] == q for inner in mod_l):
  120. net = Add(*non_mod_l) + Add(*[i.args[0] for i in mod_l])
  121. return cls(net, q)
  122. elif isinstance(p, Mul):
  123. # separating into modulus and non modulus
  124. both_l = non_mod_l, mod_l = [], []
  125. for arg in p.args:
  126. both_l[isinstance(arg, cls)].append(arg)
  127. if mod_l and all(inner.args[1] == q for inner in mod_l) and all(t.is_integer for t in p.args) and q.is_integer:
  128. # finding distributive term
  129. non_mod_l = [cls(x, q) for x in non_mod_l]
  130. mod = []
  131. non_mod = []
  132. for j in non_mod_l:
  133. if isinstance(j, cls):
  134. mod.append(j.args[0])
  135. else:
  136. non_mod.append(j)
  137. prod_mod = Mul(*mod)
  138. prod_non_mod = Mul(*non_mod)
  139. prod_mod1 = Mul(*[i.args[0] for i in mod_l])
  140. net = prod_mod1*prod_mod
  141. return prod_non_mod*cls(net, q)
  142. if q.is_Integer and q is not S.One:
  143. if all(t.is_integer for t in p.args):
  144. non_mod_l = [i % q if i.is_Integer else i for i in p.args]
  145. if any(iq is S.Zero for iq in non_mod_l):
  146. return S.Zero
  147. p = Mul(*(non_mod_l + mod_l))
  148. # XXX other possibilities?
  149. from sympy.polys.polyerrors import PolynomialError
  150. from sympy.polys.polytools import gcd
  151. # extract gcd; any further simplification should be done by the user
  152. try:
  153. G = gcd(p, q)
  154. if not equal_valued(G, 1):
  155. p, q = [gcd_terms(i/G, clear=False, fraction=False)
  156. for i in (p, q)]
  157. except PolynomialError: # issue 21373
  158. G = S.One
  159. pwas, qwas = p, q
  160. # simplify terms
  161. # (x + y + 2) % x -> Mod(y + 2, x)
  162. if p.is_Add:
  163. args = []
  164. for i in p.args:
  165. a = cls(i, q)
  166. if a.count(cls) > i.count(cls):
  167. args.append(i)
  168. else:
  169. args.append(a)
  170. if args != list(p.args):
  171. p = Add(*args)
  172. else:
  173. # handle coefficients if they are not Rational
  174. # since those are not handled by factor_terms
  175. # e.g. Mod(.6*x, .3*y) -> 0.3*Mod(2*x, y)
  176. cp, p = p.as_coeff_Mul()
  177. cq, q = q.as_coeff_Mul()
  178. ok = False
  179. if not cp.is_Rational or not cq.is_Rational:
  180. r = cp % cq
  181. if equal_valued(r, 0):
  182. G *= cq
  183. p *= int(cp/cq)
  184. ok = True
  185. if not ok:
  186. p = cp*p
  187. q = cq*q
  188. # simple -1 extraction
  189. if p.could_extract_minus_sign() and q.could_extract_minus_sign():
  190. G, p, q = [-i for i in (G, p, q)]
  191. # check again to see if p and q can now be handled as numbers
  192. rv = number_eval(p, q)
  193. if rv is not None:
  194. return rv*G
  195. # put 1.0 from G on inside
  196. if G.is_Float and equal_valued(G, 1):
  197. p *= G
  198. return cls(p, q, evaluate=False)
  199. elif G.is_Mul and G.args[0].is_Float and equal_valued(G.args[0], 1):
  200. p = G.args[0]*p
  201. G = Mul._from_args(G.args[1:])
  202. return G*cls(p, q, evaluate=(p, q) != (pwas, qwas))
  203. def _eval_is_integer(self):
  204. p, q = self.args
  205. if fuzzy_and([p.is_integer, q.is_integer, fuzzy_not(q.is_zero)]):
  206. return True
  207. def _eval_is_nonnegative(self):
  208. if self.args[1].is_positive:
  209. return True
  210. def _eval_is_nonpositive(self):
  211. if self.args[1].is_negative:
  212. return True
  213. def _eval_rewrite_as_floor(self, a, b, **kwargs):
  214. from sympy.functions.elementary.integers import floor
  215. return a - b*floor(a/b)
  216. def _eval_as_leading_term(self, x, logx, cdir):
  217. from sympy.functions.elementary.integers import floor
  218. return self.rewrite(floor)._eval_as_leading_term(x, logx=logx, cdir=cdir)
  219. def _eval_nseries(self, x, n, logx, cdir=0):
  220. from sympy.functions.elementary.integers import floor
  221. return self.rewrite(floor)._eval_nseries(x, n, logx=logx, cdir=cdir)