modular.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. from math import prod
  2. from sympy.external.gmpy import gcd, gcdext
  3. from sympy.ntheory.primetest import isprime
  4. from sympy.polys.domains import ZZ
  5. from sympy.polys.galoistools import gf_crt, gf_crt1, gf_crt2
  6. from sympy.utilities.misc import as_int
  7. def symmetric_residue(a, m):
  8. """Return the residual mod m such that it is within half of the modulus.
  9. >>> from sympy.ntheory.modular import symmetric_residue
  10. >>> symmetric_residue(1, 6)
  11. 1
  12. >>> symmetric_residue(4, 6)
  13. -2
  14. """
  15. if a <= m // 2:
  16. return a
  17. return a - m
  18. def crt(m, v, symmetric=False, check=True):
  19. r"""Chinese Remainder Theorem.
  20. The moduli in m are assumed to be pairwise coprime. The output
  21. is then an integer f, such that f = v_i mod m_i for each pair out
  22. of v and m. If ``symmetric`` is False a positive integer will be
  23. returned, else \|f\| will be less than or equal to the LCM of the
  24. moduli, and thus f may be negative.
  25. If the moduli are not co-prime the correct result will be returned
  26. if/when the test of the result is found to be incorrect. This result
  27. will be None if there is no solution.
  28. The keyword ``check`` can be set to False if it is known that the moduli
  29. are coprime.
  30. Examples
  31. ========
  32. As an example consider a set of residues ``U = [49, 76, 65]``
  33. and a set of moduli ``M = [99, 97, 95]``. Then we have::
  34. >>> from sympy.ntheory.modular import crt
  35. >>> crt([99, 97, 95], [49, 76, 65])
  36. (639985, 912285)
  37. This is the correct result because::
  38. >>> [639985 % m for m in [99, 97, 95]]
  39. [49, 76, 65]
  40. If the moduli are not co-prime, you may receive an incorrect result
  41. if you use ``check=False``:
  42. >>> crt([12, 6, 17], [3, 4, 2], check=False)
  43. (954, 1224)
  44. >>> [954 % m for m in [12, 6, 17]]
  45. [6, 0, 2]
  46. >>> crt([12, 6, 17], [3, 4, 2]) is None
  47. True
  48. >>> crt([3, 6], [2, 5])
  49. (5, 6)
  50. Note: the order of gf_crt's arguments is reversed relative to crt,
  51. and that solve_congruence takes residue, modulus pairs.
  52. Programmer's note: rather than checking that all pairs of moduli share
  53. no GCD (an O(n**2) test) and rather than factoring all moduli and seeing
  54. that there is no factor in common, a check that the result gives the
  55. indicated residuals is performed -- an O(n) operation.
  56. See Also
  57. ========
  58. solve_congruence
  59. sympy.polys.galoistools.gf_crt : low level crt routine used by this routine
  60. """
  61. if check:
  62. m = list(map(as_int, m))
  63. v = list(map(as_int, v))
  64. result = gf_crt(v, m, ZZ)
  65. mm = prod(m)
  66. if check:
  67. if not all(v % m == result % m for v, m in zip(v, m)):
  68. result = solve_congruence(*list(zip(v, m)),
  69. check=False, symmetric=symmetric)
  70. if result is None:
  71. return result
  72. result, mm = result
  73. if symmetric:
  74. return int(symmetric_residue(result, mm)), int(mm)
  75. return int(result), int(mm)
  76. def crt1(m):
  77. """First part of Chinese Remainder Theorem, for multiple application.
  78. Examples
  79. ========
  80. >>> from sympy.ntheory.modular import crt, crt1, crt2
  81. >>> m = [99, 97, 95]
  82. >>> v = [49, 76, 65]
  83. The following two codes have the same result.
  84. >>> crt(m, v)
  85. (639985, 912285)
  86. >>> mm, e, s = crt1(m)
  87. >>> crt2(m, v, mm, e, s)
  88. (639985, 912285)
  89. However, it is faster when we want to fix ``m`` and
  90. compute for multiple ``v``, i.e. the following cases:
  91. >>> mm, e, s = crt1(m)
  92. >>> vs = [[52, 21, 37], [19, 46, 76]]
  93. >>> for v in vs:
  94. ... print(crt2(m, v, mm, e, s))
  95. (397042, 912285)
  96. (803206, 912285)
  97. See Also
  98. ========
  99. sympy.polys.galoistools.gf_crt1 : low level crt routine used by this routine
  100. sympy.ntheory.modular.crt
  101. sympy.ntheory.modular.crt2
  102. """
  103. return gf_crt1(m, ZZ)
  104. def crt2(m, v, mm, e, s, symmetric=False):
  105. """Second part of Chinese Remainder Theorem, for multiple application.
  106. See ``crt1`` for usage.
  107. Examples
  108. ========
  109. >>> from sympy.ntheory.modular import crt1, crt2
  110. >>> mm, e, s = crt1([18, 42, 6])
  111. >>> crt2([18, 42, 6], [0, 0, 0], mm, e, s)
  112. (0, 4536)
  113. See Also
  114. ========
  115. sympy.polys.galoistools.gf_crt2 : low level crt routine used by this routine
  116. sympy.ntheory.modular.crt
  117. sympy.ntheory.modular.crt1
  118. """
  119. result = gf_crt2(v, m, mm, e, s, ZZ)
  120. if symmetric:
  121. return int(symmetric_residue(result, mm)), int(mm)
  122. return int(result), int(mm)
  123. def solve_congruence(*remainder_modulus_pairs, **hint):
  124. """Compute the integer ``n`` that has the residual ``ai`` when it is
  125. divided by ``mi`` where the ``ai`` and ``mi`` are given as pairs to
  126. this function: ((a1, m1), (a2, m2), ...). If there is no solution,
  127. return None. Otherwise return ``n`` and its modulus.
  128. The ``mi`` values need not be co-prime. If it is known that the moduli are
  129. not co-prime then the hint ``check`` can be set to False (default=True) and
  130. the check for a quicker solution via crt() (valid when the moduli are
  131. co-prime) will be skipped.
  132. If the hint ``symmetric`` is True (default is False), the value of ``n``
  133. will be within 1/2 of the modulus, possibly negative.
  134. Examples
  135. ========
  136. >>> from sympy.ntheory.modular import solve_congruence
  137. What number is 2 mod 3, 3 mod 5 and 2 mod 7?
  138. >>> solve_congruence((2, 3), (3, 5), (2, 7))
  139. (23, 105)
  140. >>> [23 % m for m in [3, 5, 7]]
  141. [2, 3, 2]
  142. If you prefer to work with all remainder in one list and
  143. all moduli in another, send the arguments like this:
  144. >>> solve_congruence(*zip((2, 3, 2), (3, 5, 7)))
  145. (23, 105)
  146. The moduli need not be co-prime; in this case there may or
  147. may not be a solution:
  148. >>> solve_congruence((2, 3), (4, 6)) is None
  149. True
  150. >>> solve_congruence((2, 3), (5, 6))
  151. (5, 6)
  152. The symmetric flag will make the result be within 1/2 of the modulus:
  153. >>> solve_congruence((2, 3), (5, 6), symmetric=True)
  154. (-1, 6)
  155. See Also
  156. ========
  157. crt : high level routine implementing the Chinese Remainder Theorem
  158. """
  159. def combine(c1, c2):
  160. """Return the tuple (a, m) which satisfies the requirement
  161. that n = a + i*m satisfy n = a1 + j*m1 and n = a2 = k*m2.
  162. References
  163. ==========
  164. .. [1] https://en.wikipedia.org/wiki/Method_of_successive_substitution
  165. """
  166. a1, m1 = c1
  167. a2, m2 = c2
  168. a, b, c = m1, a2 - a1, m2
  169. g = gcd(a, b, c)
  170. a, b, c = [i//g for i in [a, b, c]]
  171. if a != 1:
  172. g, inv_a, _ = gcdext(a, c)
  173. if g != 1:
  174. return None
  175. b *= inv_a
  176. a, m = a1 + m1*b, m1*c
  177. return a, m
  178. rm = remainder_modulus_pairs
  179. symmetric = hint.get('symmetric', False)
  180. if hint.get('check', True):
  181. rm = [(as_int(r), as_int(m)) for r, m in rm]
  182. # ignore redundant pairs but raise an error otherwise; also
  183. # make sure that a unique set of bases is sent to gf_crt if
  184. # they are all prime.
  185. #
  186. # The routine will work out less-trivial violations and
  187. # return None, e.g. for the pairs (1,3) and (14,42) there
  188. # is no answer because 14 mod 42 (having a gcd of 14) implies
  189. # (14/2) mod (42/2), (14/7) mod (42/7) and (14/14) mod (42/14)
  190. # which, being 0 mod 3, is inconsistent with 1 mod 3. But to
  191. # preprocess the input beyond checking of another pair with 42
  192. # or 3 as the modulus (for this example) is not necessary.
  193. uniq = {}
  194. for r, m in rm:
  195. r %= m
  196. if m in uniq:
  197. if r != uniq[m]:
  198. return None
  199. continue
  200. uniq[m] = r
  201. rm = [(r, m) for m, r in uniq.items()]
  202. del uniq
  203. # if the moduli are co-prime, the crt will be significantly faster;
  204. # checking all pairs for being co-prime gets to be slow but a prime
  205. # test is a good trade-off
  206. if all(isprime(m) for r, m in rm):
  207. r, m = list(zip(*rm))
  208. return crt(m, r, symmetric=symmetric, check=False)
  209. rv = (0, 1)
  210. for rmi in rm:
  211. rv = combine(rv, rmi)
  212. if rv is None:
  213. break
  214. n, m = rv
  215. n = n % m
  216. else:
  217. if symmetric:
  218. return symmetric_residue(n, m), m
  219. return n, m