recurrence.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. """Recurrence Operators"""
  2. from sympy.core.singleton import S
  3. from sympy.core.symbol import (Symbol, symbols)
  4. from sympy.printing import sstr
  5. from sympy.core.sympify import sympify
  6. def RecurrenceOperators(base, generator):
  7. """
  8. Returns an Algebra of Recurrence Operators and the operator for
  9. shifting i.e. the `Sn` operator.
  10. The first argument needs to be the base polynomial ring for the algebra
  11. and the second argument must be a generator which can be either a
  12. noncommutative Symbol or a string.
  13. Examples
  14. ========
  15. >>> from sympy import ZZ
  16. >>> from sympy import symbols
  17. >>> from sympy.holonomic.recurrence import RecurrenceOperators
  18. >>> n = symbols('n', integer=True)
  19. >>> R, Sn = RecurrenceOperators(ZZ.old_poly_ring(n), 'Sn')
  20. """
  21. ring = RecurrenceOperatorAlgebra(base, generator)
  22. return (ring, ring.shift_operator)
  23. class RecurrenceOperatorAlgebra:
  24. """
  25. A Recurrence Operator Algebra is a set of noncommutative polynomials
  26. in intermediate `Sn` and coefficients in a base ring A. It follows the
  27. commutation rule:
  28. Sn * a(n) = a(n + 1) * Sn
  29. This class represents a Recurrence Operator Algebra and serves as the parent ring
  30. for Recurrence Operators.
  31. Examples
  32. ========
  33. >>> from sympy import ZZ
  34. >>> from sympy import symbols
  35. >>> from sympy.holonomic.recurrence import RecurrenceOperators
  36. >>> n = symbols('n', integer=True)
  37. >>> R, Sn = RecurrenceOperators(ZZ.old_poly_ring(n), 'Sn')
  38. >>> R
  39. Univariate Recurrence Operator Algebra in intermediate Sn over the base ring
  40. ZZ[n]
  41. See Also
  42. ========
  43. RecurrenceOperator
  44. """
  45. def __init__(self, base, generator):
  46. # the base ring for the algebra
  47. self.base = base
  48. # the operator representing shift i.e. `Sn`
  49. self.shift_operator = RecurrenceOperator(
  50. [base.zero, base.one], self)
  51. if generator is None:
  52. self.gen_symbol = symbols('Sn', commutative=False)
  53. else:
  54. if isinstance(generator, str):
  55. self.gen_symbol = symbols(generator, commutative=False)
  56. elif isinstance(generator, Symbol):
  57. self.gen_symbol = generator
  58. def __str__(self):
  59. string = 'Univariate Recurrence Operator Algebra in intermediate '\
  60. + sstr(self.gen_symbol) + ' over the base ring ' + \
  61. (self.base).__str__()
  62. return string
  63. __repr__ = __str__
  64. def __eq__(self, other):
  65. if self.base == other.base and self.gen_symbol == other.gen_symbol:
  66. return True
  67. else:
  68. return False
  69. def _add_lists(list1, list2):
  70. if len(list1) <= len(list2):
  71. sol = [a + b for a, b in zip(list1, list2)] + list2[len(list1):]
  72. else:
  73. sol = [a + b for a, b in zip(list1, list2)] + list1[len(list2):]
  74. return sol
  75. class RecurrenceOperator:
  76. """
  77. The Recurrence Operators are defined by a list of polynomials
  78. in the base ring and the parent ring of the Operator.
  79. Explanation
  80. ===========
  81. Takes a list of polynomials for each power of Sn and the
  82. parent ring which must be an instance of RecurrenceOperatorAlgebra.
  83. A Recurrence Operator can be created easily using
  84. the operator `Sn`. See examples below.
  85. Examples
  86. ========
  87. >>> from sympy.holonomic.recurrence import RecurrenceOperator, RecurrenceOperators
  88. >>> from sympy import ZZ
  89. >>> from sympy import symbols
  90. >>> n = symbols('n', integer=True)
  91. >>> R, Sn = RecurrenceOperators(ZZ.old_poly_ring(n),'Sn')
  92. >>> RecurrenceOperator([0, 1, n**2], R)
  93. (1)Sn + (n**2)Sn**2
  94. >>> Sn*n
  95. (n + 1)Sn
  96. >>> n*Sn*n + 1 - Sn**2*n
  97. (1) + (n**2 + n)Sn + (-n - 2)Sn**2
  98. See Also
  99. ========
  100. DifferentialOperatorAlgebra
  101. """
  102. _op_priority = 20
  103. def __init__(self, list_of_poly, parent):
  104. # the parent ring for this operator
  105. # must be an RecurrenceOperatorAlgebra object
  106. self.parent = parent
  107. # sequence of polynomials in n for each power of Sn
  108. # represents the operator
  109. # convert the expressions into ring elements using from_sympy
  110. if isinstance(list_of_poly, list):
  111. for i, j in enumerate(list_of_poly):
  112. if isinstance(j, int):
  113. list_of_poly[i] = self.parent.base.from_sympy(S(j))
  114. elif not isinstance(j, self.parent.base.dtype):
  115. list_of_poly[i] = self.parent.base.from_sympy(j)
  116. self.listofpoly = list_of_poly
  117. self.order = len(self.listofpoly) - 1
  118. def __mul__(self, other):
  119. """
  120. Multiplies two Operators and returns another
  121. RecurrenceOperator instance using the commutation rule
  122. Sn * a(n) = a(n + 1) * Sn
  123. """
  124. listofself = self.listofpoly
  125. base = self.parent.base
  126. if not isinstance(other, RecurrenceOperator):
  127. if not isinstance(other, self.parent.base.dtype):
  128. listofother = [self.parent.base.from_sympy(sympify(other))]
  129. else:
  130. listofother = [other]
  131. else:
  132. listofother = other.listofpoly
  133. # multiply a polynomial `b` with a list of polynomials
  134. def _mul_dmp_diffop(b, listofother):
  135. if isinstance(listofother, list):
  136. return [i * b for i in listofother]
  137. return [b * listofother]
  138. sol = _mul_dmp_diffop(listofself[0], listofother)
  139. # compute Sn^i * b
  140. def _mul_Sni_b(b):
  141. sol = [base.zero]
  142. if isinstance(b, list):
  143. for i in b:
  144. j = base.to_sympy(i).subs(base.gens[0], base.gens[0] + S.One)
  145. sol.append(base.from_sympy(j))
  146. else:
  147. j = b.subs(base.gens[0], base.gens[0] + S.One)
  148. sol.append(base.from_sympy(j))
  149. return sol
  150. for i in range(1, len(listofself)):
  151. # find Sn^i * b in ith iteration
  152. listofother = _mul_Sni_b(listofother)
  153. # solution = solution + listofself[i] * (Sn^i * b)
  154. sol = _add_lists(sol, _mul_dmp_diffop(listofself[i], listofother))
  155. return RecurrenceOperator(sol, self.parent)
  156. def __rmul__(self, other):
  157. if not isinstance(other, RecurrenceOperator):
  158. if isinstance(other, int):
  159. other = S(other)
  160. if not isinstance(other, self.parent.base.dtype):
  161. other = (self.parent.base).from_sympy(other)
  162. sol = [other * j for j in self.listofpoly]
  163. return RecurrenceOperator(sol, self.parent)
  164. def __add__(self, other):
  165. if isinstance(other, RecurrenceOperator):
  166. sol = _add_lists(self.listofpoly, other.listofpoly)
  167. return RecurrenceOperator(sol, self.parent)
  168. else:
  169. if isinstance(other, int):
  170. other = S(other)
  171. list_self = self.listofpoly
  172. if not isinstance(other, self.parent.base.dtype):
  173. list_other = [((self.parent).base).from_sympy(other)]
  174. else:
  175. list_other = [other]
  176. sol = [list_self[0] + list_other[0]] + list_self[1:]
  177. return RecurrenceOperator(sol, self.parent)
  178. __radd__ = __add__
  179. def __sub__(self, other):
  180. return self + (-1) * other
  181. def __rsub__(self, other):
  182. return (-1) * self + other
  183. def __pow__(self, n):
  184. if n == 1:
  185. return self
  186. result = RecurrenceOperator([self.parent.base.one], self.parent)
  187. if n == 0:
  188. return result
  189. # if self is `Sn`
  190. if self.listofpoly == self.parent.shift_operator.listofpoly:
  191. sol = [self.parent.base.zero] * n + [self.parent.base.one]
  192. return RecurrenceOperator(sol, self.parent)
  193. x = self
  194. while True:
  195. if n % 2:
  196. result *= x
  197. n >>= 1
  198. if not n:
  199. break
  200. x *= x
  201. return result
  202. def __str__(self):
  203. listofpoly = self.listofpoly
  204. print_str = ''
  205. for i, j in enumerate(listofpoly):
  206. if j == self.parent.base.zero:
  207. continue
  208. j = self.parent.base.to_sympy(j)
  209. if i == 0:
  210. print_str += '(' + sstr(j) + ')'
  211. continue
  212. if print_str:
  213. print_str += ' + '
  214. if i == 1:
  215. print_str += '(' + sstr(j) + ')Sn'
  216. continue
  217. print_str += '(' + sstr(j) + ')' + 'Sn**' + sstr(i)
  218. return print_str
  219. __repr__ = __str__
  220. def __eq__(self, other):
  221. if isinstance(other, RecurrenceOperator):
  222. if self.listofpoly == other.listofpoly and self.parent == other.parent:
  223. return True
  224. else:
  225. return False
  226. return self.listofpoly[0] == other and \
  227. all(i is self.parent.base.zero for i in self.listofpoly[1:])
  228. class HolonomicSequence:
  229. """
  230. A Holonomic Sequence is a type of sequence satisfying a linear homogeneous
  231. recurrence relation with Polynomial coefficients. Alternatively, A sequence
  232. is Holonomic if and only if its generating function is a Holonomic Function.
  233. """
  234. def __init__(self, recurrence, u0=[]):
  235. self.recurrence = recurrence
  236. if not isinstance(u0, list):
  237. self.u0 = [u0]
  238. else:
  239. self.u0 = u0
  240. if len(self.u0) == 0:
  241. self._have_init_cond = False
  242. else:
  243. self._have_init_cond = True
  244. self.n = recurrence.parent.base.gens[0]
  245. def __repr__(self):
  246. str_sol = 'HolonomicSequence(%s, %s)' % ((self.recurrence).__repr__(), sstr(self.n))
  247. if not self._have_init_cond:
  248. return str_sol
  249. else:
  250. cond_str = ''
  251. seq_str = 0
  252. for i in self.u0:
  253. cond_str += ', u(%s) = %s' % (sstr(seq_str), sstr(i))
  254. seq_str += 1
  255. sol = str_sol + cond_str
  256. return sol
  257. __str__ = __repr__
  258. def __eq__(self, other):
  259. if self.recurrence != other.recurrence or self.n != other.n:
  260. return False
  261. if self._have_init_cond and other._have_init_cond:
  262. return self.u0 == other.u0
  263. return True