modularinteger.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. """Implementation of :class:`ModularInteger` class. """
  2. from __future__ import annotations
  3. from typing import Any
  4. import operator
  5. from sympy.polys.polyutils import PicklableWithSlots
  6. from sympy.polys.polyerrors import CoercionFailed
  7. from sympy.polys.domains.domainelement import DomainElement
  8. from sympy.utilities import public
  9. from sympy.utilities.exceptions import sympy_deprecation_warning
  10. @public
  11. class ModularInteger(PicklableWithSlots, DomainElement):
  12. """A class representing a modular integer. """
  13. mod, dom, sym, _parent = None, None, None, None
  14. __slots__ = ('val',)
  15. def parent(self):
  16. return self._parent
  17. def __init__(self, val):
  18. if isinstance(val, self.__class__):
  19. self.val = val.val % self.mod
  20. else:
  21. self.val = self.dom.convert(val) % self.mod
  22. def modulus(self):
  23. return self.mod
  24. def __hash__(self):
  25. return hash((self.val, self.mod))
  26. def __repr__(self):
  27. return "%s(%s)" % (self.__class__.__name__, self.val)
  28. def __str__(self):
  29. return "%s mod %s" % (self.val, self.mod)
  30. def __int__(self):
  31. return int(self.val)
  32. def to_int(self):
  33. sympy_deprecation_warning(
  34. """ModularInteger.to_int() is deprecated.
  35. Use int(a) or K = GF(p) and K.to_int(a) instead of a.to_int().
  36. """,
  37. deprecated_since_version="1.13",
  38. active_deprecations_target="modularinteger-to-int",
  39. )
  40. if self.sym:
  41. if self.val <= self.mod // 2:
  42. return self.val
  43. else:
  44. return self.val - self.mod
  45. else:
  46. return self.val
  47. def __pos__(self):
  48. return self
  49. def __neg__(self):
  50. return self.__class__(-self.val)
  51. @classmethod
  52. def _get_val(cls, other):
  53. if isinstance(other, cls):
  54. return other.val
  55. else:
  56. try:
  57. return cls.dom.convert(other)
  58. except CoercionFailed:
  59. return None
  60. def __add__(self, other):
  61. val = self._get_val(other)
  62. if val is not None:
  63. return self.__class__(self.val + val)
  64. else:
  65. return NotImplemented
  66. def __radd__(self, other):
  67. return self.__add__(other)
  68. def __sub__(self, other):
  69. val = self._get_val(other)
  70. if val is not None:
  71. return self.__class__(self.val - val)
  72. else:
  73. return NotImplemented
  74. def __rsub__(self, other):
  75. return (-self).__add__(other)
  76. def __mul__(self, other):
  77. val = self._get_val(other)
  78. if val is not None:
  79. return self.__class__(self.val * val)
  80. else:
  81. return NotImplemented
  82. def __rmul__(self, other):
  83. return self.__mul__(other)
  84. def __truediv__(self, other):
  85. val = self._get_val(other)
  86. if val is not None:
  87. return self.__class__(self.val * self._invert(val))
  88. else:
  89. return NotImplemented
  90. def __rtruediv__(self, other):
  91. return self.invert().__mul__(other)
  92. def __mod__(self, other):
  93. val = self._get_val(other)
  94. if val is not None:
  95. return self.__class__(self.val % val)
  96. else:
  97. return NotImplemented
  98. def __rmod__(self, other):
  99. val = self._get_val(other)
  100. if val is not None:
  101. return self.__class__(val % self.val)
  102. else:
  103. return NotImplemented
  104. def __pow__(self, exp):
  105. if not exp:
  106. return self.__class__(self.dom.one)
  107. if exp < 0:
  108. val, exp = self.invert().val, -exp
  109. else:
  110. val = self.val
  111. return self.__class__(pow(val, int(exp), self.mod))
  112. def _compare(self, other, op):
  113. val = self._get_val(other)
  114. if val is None:
  115. return NotImplemented
  116. return op(self.val, val % self.mod)
  117. def _compare_deprecated(self, other, op):
  118. val = self._get_val(other)
  119. if val is None:
  120. return NotImplemented
  121. sympy_deprecation_warning(
  122. """Ordered comparisons with modular integers are deprecated.
  123. Use e.g. int(a) < int(b) instead of a < b.
  124. """,
  125. deprecated_since_version="1.13",
  126. active_deprecations_target="modularinteger-compare",
  127. stacklevel=4,
  128. )
  129. return op(self.val, val % self.mod)
  130. def __eq__(self, other):
  131. return self._compare(other, operator.eq)
  132. def __ne__(self, other):
  133. return self._compare(other, operator.ne)
  134. def __lt__(self, other):
  135. return self._compare_deprecated(other, operator.lt)
  136. def __le__(self, other):
  137. return self._compare_deprecated(other, operator.le)
  138. def __gt__(self, other):
  139. return self._compare_deprecated(other, operator.gt)
  140. def __ge__(self, other):
  141. return self._compare_deprecated(other, operator.ge)
  142. def __bool__(self):
  143. return bool(self.val)
  144. @classmethod
  145. def _invert(cls, value):
  146. return cls.dom.invert(value, cls.mod)
  147. def invert(self):
  148. return self.__class__(self._invert(self.val))
  149. _modular_integer_cache: dict[tuple[Any, Any, Any], type[ModularInteger]] = {}
  150. def ModularIntegerFactory(_mod, _dom, _sym, parent):
  151. """Create custom class for specific integer modulus."""
  152. try:
  153. _mod = _dom.convert(_mod)
  154. except CoercionFailed:
  155. ok = False
  156. else:
  157. ok = True
  158. if not ok or _mod < 1:
  159. raise ValueError("modulus must be a positive integer, got %s" % _mod)
  160. key = _mod, _dom, _sym
  161. try:
  162. cls = _modular_integer_cache[key]
  163. except KeyError:
  164. class cls(ModularInteger):
  165. mod, dom, sym = _mod, _dom, _sym
  166. _parent = parent
  167. if _sym:
  168. cls.__name__ = "SymmetricModularIntegerMod%s" % _mod
  169. else:
  170. cls.__name__ = "ModularIntegerMod%s" % _mod
  171. _modular_integer_cache[key] = cls
  172. return cls