delta.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
"""
This module implements sums and products containing the Kronecker Delta function.

References
==========

.. [1] https://mathworld.wolfram.com/KroneckerDelta.html
"""
  7. from .products import product
  8. from .summations import Sum, summation
  9. from sympy.core import Add, Mul, S, Dummy
  10. from sympy.core.cache import cacheit
  11. from sympy.core.sorting import default_sort_key
  12. from sympy.functions import KroneckerDelta, Piecewise, piecewise_fold
  13. from sympy.polys.polytools import factor
  14. from sympy.sets.sets import Interval
  15. from sympy.solvers.solvers import solve
  16. @cacheit
  17. def _expand_delta(expr, index):
  18. """
  19. Expand the first Add containing a simple KroneckerDelta.
  20. """
  21. if not expr.is_Mul:
  22. return expr
  23. delta = None
  24. func = Add
  25. terms = [S.One]
  26. for h in expr.args:
  27. if delta is None and h.is_Add and _has_simple_delta(h, index):
  28. delta = True
  29. func = h.func
  30. terms = [terms[0]*t for t in h.args]
  31. else:
  32. terms = [t*h for t in terms]
  33. return func(*terms)
  34. @cacheit
  35. def _extract_delta(expr, index):
  36. """
  37. Extract a simple KroneckerDelta from the expression.
  38. Explanation
  39. ===========
  40. Returns the tuple ``(delta, newexpr)`` where:
  41. - ``delta`` is a simple KroneckerDelta expression if one was found,
  42. or ``None`` if no simple KroneckerDelta expression was found.
  43. - ``newexpr`` is a Mul containing the remaining terms; ``expr`` is
  44. returned unchanged if no simple KroneckerDelta expression was found.
  45. Examples
  46. ========
  47. >>> from sympy import KroneckerDelta
  48. >>> from sympy.concrete.delta import _extract_delta
  49. >>> from sympy.abc import x, y, i, j, k
  50. >>> _extract_delta(4*x*y*KroneckerDelta(i, j), i)
  51. (KroneckerDelta(i, j), 4*x*y)
  52. >>> _extract_delta(4*x*y*KroneckerDelta(i, j), k)
  53. (None, 4*x*y*KroneckerDelta(i, j))
  54. See Also
  55. ========
  56. sympy.functions.special.tensor_functions.KroneckerDelta
  57. deltaproduct
  58. deltasummation
  59. """
  60. if not _has_simple_delta(expr, index):
  61. return (None, expr)
  62. if isinstance(expr, KroneckerDelta):
  63. return (expr, S.One)
  64. if not expr.is_Mul:
  65. raise ValueError("Incorrect expr")
  66. delta = None
  67. terms = []
  68. for arg in expr.args:
  69. if delta is None and _is_simple_delta(arg, index):
  70. delta = arg
  71. else:
  72. terms.append(arg)
  73. return (delta, expr.func(*terms))
  74. @cacheit
  75. def _has_simple_delta(expr, index):
  76. """
  77. Returns True if ``expr`` is an expression that contains a KroneckerDelta
  78. that is simple in the index ``index``, meaning that this KroneckerDelta
  79. is nonzero for a single value of the index ``index``.
  80. """
  81. if expr.has(KroneckerDelta):
  82. if _is_simple_delta(expr, index):
  83. return True
  84. if expr.is_Add or expr.is_Mul:
  85. return any(_has_simple_delta(arg, index) for arg in expr.args)
  86. return False
  87. @cacheit
  88. def _is_simple_delta(delta, index):
  89. """
  90. Returns True if ``delta`` is a KroneckerDelta and is nonzero for a single
  91. value of the index ``index``.
  92. """
  93. if isinstance(delta, KroneckerDelta) and delta.has(index):
  94. p = (delta.args[0] - delta.args[1]).as_poly(index)
  95. if p:
  96. return p.degree() == 1
  97. return False
  98. @cacheit
  99. def _remove_multiple_delta(expr):
  100. """
  101. Evaluate products of KroneckerDelta's.
  102. """
  103. if expr.is_Add:
  104. return expr.func(*list(map(_remove_multiple_delta, expr.args)))
  105. if not expr.is_Mul:
  106. return expr
  107. eqs = []
  108. newargs = []
  109. for arg in expr.args:
  110. if isinstance(arg, KroneckerDelta):
  111. eqs.append(arg.args[0] - arg.args[1])
  112. else:
  113. newargs.append(arg)
  114. if not eqs:
  115. return expr
  116. solns = solve(eqs, dict=True)
  117. if len(solns) == 0:
  118. return S.Zero
  119. elif len(solns) == 1:
  120. newargs += [KroneckerDelta(k, v) for k, v in solns[0].items()]
  121. expr2 = expr.func(*newargs)
  122. if expr != expr2:
  123. return _remove_multiple_delta(expr2)
  124. return expr
  125. @cacheit
  126. def _simplify_delta(expr):
  127. """
  128. Rewrite a KroneckerDelta's indices in its simplest form.
  129. """
  130. if isinstance(expr, KroneckerDelta):
  131. try:
  132. slns = solve(expr.args[0] - expr.args[1], dict=True)
  133. if slns and len(slns) == 1:
  134. return Mul(*[KroneckerDelta(*(key, value))
  135. for key, value in slns[0].items()])
  136. except NotImplementedError:
  137. pass
  138. return expr
  139. @cacheit
  140. def deltaproduct(f, limit):
  141. """
  142. Handle products containing a KroneckerDelta.
  143. See Also
  144. ========
  145. deltasummation
  146. sympy.functions.special.tensor_functions.KroneckerDelta
  147. sympy.concrete.products.product
  148. """
  149. if ((limit[2] - limit[1]) < 0) == True:
  150. return S.One
  151. if not f.has(KroneckerDelta):
  152. return product(f, limit)
  153. if f.is_Add:
  154. # Identify the term in the Add that has a simple KroneckerDelta
  155. delta = None
  156. terms = []
  157. for arg in sorted(f.args, key=default_sort_key):
  158. if delta is None and _has_simple_delta(arg, limit[0]):
  159. delta = arg
  160. else:
  161. terms.append(arg)
  162. newexpr = f.func(*terms)
  163. k = Dummy("kprime", integer=True)
  164. if isinstance(limit[1], int) and isinstance(limit[2], int):
  165. result = deltaproduct(newexpr, limit) + sum(deltaproduct(newexpr, (limit[0], limit[1], ik - 1)) *
  166. delta.subs(limit[0], ik) *
  167. deltaproduct(newexpr, (limit[0], ik + 1, limit[2])) for ik in range(int(limit[1]), int(limit[2] + 1))
  168. )
  169. else:
  170. result = deltaproduct(newexpr, limit) + deltasummation(
  171. deltaproduct(newexpr, (limit[0], limit[1], k - 1)) *
  172. delta.subs(limit[0], k) *
  173. deltaproduct(newexpr, (limit[0], k + 1, limit[2])),
  174. (k, limit[1], limit[2]),
  175. no_piecewise=_has_simple_delta(newexpr, limit[0])
  176. )
  177. return _remove_multiple_delta(result)
  178. delta, _ = _extract_delta(f, limit[0])
  179. if not delta:
  180. g = _expand_delta(f, limit[0])
  181. if f != g:
  182. try:
  183. return factor(deltaproduct(g, limit))
  184. except AssertionError:
  185. return deltaproduct(g, limit)
  186. return product(f, limit)
  187. return _remove_multiple_delta(f.subs(limit[0], limit[1])*KroneckerDelta(limit[2], limit[1])) + \
  188. S.One*_simplify_delta(KroneckerDelta(limit[2], limit[1] - 1))
  189. @cacheit
  190. def deltasummation(f, limit, no_piecewise=False):
  191. """
  192. Handle summations containing a KroneckerDelta.
  193. Explanation
  194. ===========
  195. The idea for summation is the following:
  196. - If we are dealing with a KroneckerDelta expression, i.e. KroneckerDelta(g(x), j),
  197. we try to simplify it.
  198. If we could simplify it, then we sum the resulting expression.
  199. We already know we can sum a simplified expression, because only
  200. simple KroneckerDelta expressions are involved.
  201. If we could not simplify it, there are two cases:
  202. 1) The expression is a simple expression: we return the summation,
  203. taking care if we are dealing with a Derivative or with a proper
  204. KroneckerDelta.
  205. 2) The expression is not simple (i.e. KroneckerDelta(cos(x))): we can do
  206. nothing at all.
  207. - If the expr is a multiplication expr having a KroneckerDelta term:
  208. First we expand it.
  209. If the expansion did work, then we try to sum the expansion.
  210. If not, we try to extract a simple KroneckerDelta term, then we have two
  211. cases:
  212. 1) We have a simple KroneckerDelta term, so we return the summation.
  213. 2) We did not have a simple term, but we do have an expression with
  214. simplified KroneckerDelta terms, so we sum this expression.
  215. Examples
  216. ========
  217. >>> from sympy import oo, symbols
  218. >>> from sympy.abc import k
  219. >>> i, j = symbols('i, j', integer=True, finite=True)
  220. >>> from sympy.concrete.delta import deltasummation
  221. >>> from sympy import KroneckerDelta
  222. >>> deltasummation(KroneckerDelta(i, k), (k, -oo, oo))
  223. 1
  224. >>> deltasummation(KroneckerDelta(i, k), (k, 0, oo))
  225. Piecewise((1, i >= 0), (0, True))
  226. >>> deltasummation(KroneckerDelta(i, k), (k, 1, 3))
  227. Piecewise((1, (i >= 1) & (i <= 3)), (0, True))
  228. >>> deltasummation(k*KroneckerDelta(i, j)*KroneckerDelta(j, k), (k, -oo, oo))
  229. j*KroneckerDelta(i, j)
  230. >>> deltasummation(j*KroneckerDelta(i, j), (j, -oo, oo))
  231. i
  232. >>> deltasummation(i*KroneckerDelta(i, j), (i, -oo, oo))
  233. j
  234. See Also
  235. ========
  236. deltaproduct
  237. sympy.functions.special.tensor_functions.KroneckerDelta
  238. sympy.concrete.sums.summation
  239. """
  240. if ((limit[2] - limit[1]) < 0) == True:
  241. return S.Zero
  242. if not f.has(KroneckerDelta):
  243. return summation(f, limit)
  244. x = limit[0]
  245. g = _expand_delta(f, x)
  246. if g.is_Add:
  247. return piecewise_fold(
  248. g.func(*[deltasummation(h, limit, no_piecewise) for h in g.args]))
  249. # try to extract a simple KroneckerDelta term
  250. delta, expr = _extract_delta(g, x)
  251. if (delta is not None) and (delta.delta_range is not None):
  252. dinf, dsup = delta.delta_range
  253. if (limit[1] - dinf <= 0) == True and (limit[2] - dsup >= 0) == True:
  254. no_piecewise = True
  255. if not delta:
  256. return summation(f, limit)
  257. solns = solve(delta.args[0] - delta.args[1], x)
  258. if len(solns) == 0:
  259. return S.Zero
  260. elif len(solns) != 1:
  261. return Sum(f, limit)
  262. value = solns[0]
  263. if no_piecewise:
  264. return expr.subs(x, value)
  265. return Piecewise(
  266. (expr.subs(x, value), Interval(*limit[1:3]).as_relational(value)),
  267. (S.Zero, True)
  268. )