_cse_diff.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. """Module for differentiation using CSE."""
  2. from sympy import cse, Matrix, Derivative, MatrixBase
  3. from sympy.utilities.iterables import iterable
  4. def _remove_cse_from_derivative(replacements, reduced_expressions):
  5. """
  6. This function is designed to postprocess the output of a common subexpression
  7. elimination (CSE) operation. Specifically, it removes any CSE replacement
  8. symbols from the arguments of ``Derivative`` terms in the expression. This
  9. is necessary to ensure that the forward Jacobian function correctly handles
  10. derivative terms.
  11. Parameters
  12. ==========
  13. replacements : list of (Symbol, expression) pairs
  14. Replacement symbols and relative common subexpressions that have been
  15. replaced during a CSE operation.
  16. reduced_expressions : list of SymPy expressions
  17. The reduced expressions with all the replacements from the
  18. replacements list above.
  19. Returns
  20. =======
  21. processed_replacements : list of (Symbol, expression) pairs
  22. Processed replacement list, in the same format of the
  23. ``replacements`` input list.
  24. processed_reduced : list of SymPy expressions
  25. Processed reduced list, in the same format of the
  26. ``reduced_expressions`` input list.
  27. """
  28. def traverse(node, repl_dict):
  29. if isinstance(node, Derivative):
  30. return replace_all(node, repl_dict)
  31. if not node.args:
  32. return node
  33. new_args = [traverse(arg, repl_dict) for arg in node.args]
  34. return node.func(*new_args)
  35. def replace_all(node, repl_dict):
  36. result = node
  37. while True:
  38. free_symbols = result.free_symbols
  39. symbols_dict = {k: repl_dict[k] for k in free_symbols if k in repl_dict}
  40. if not symbols_dict:
  41. break
  42. result = result.xreplace(symbols_dict)
  43. return result
  44. repl_dict = dict(replacements)
  45. processed_replacements = [
  46. (rep_sym, traverse(sub_exp, repl_dict))
  47. for rep_sym, sub_exp in replacements
  48. ]
  49. processed_reduced = [
  50. red_exp.__class__([traverse(exp, repl_dict) for exp in red_exp])
  51. for red_exp in reduced_expressions
  52. ]
  53. return processed_replacements, processed_reduced
  54. def _forward_jacobian_cse(replacements, reduced_expr, wrt):
  55. """
  56. Core function to compute the Jacobian of an input Matrix of expressions
  57. through forward accumulation. Takes directly the output of a CSE operation
  58. (replacements and reduced_expr), and an iterable of variables (wrt) with
  59. respect to which to differentiate the reduced expression and returns the
  60. reduced Jacobian matrix and the ``replacements`` list.
  61. The function also returns a list of precomputed free symbols for each
  62. subexpression, which are useful in the substitution process.
  63. Parameters
  64. ==========
  65. replacements : list of (Symbol, expression) pairs
  66. Replacement symbols and relative common subexpressions that have been
  67. replaced during a CSE operation.
  68. reduced_expr : list of SymPy expressions
  69. The reduced expressions with all the replacements from the
  70. replacements list above.
  71. wrt : iterable
  72. Iterable of expressions with respect to which to compute the
  73. Jacobian matrix.
  74. Returns
  75. =======
  76. replacements : list of (Symbol, expression) pairs
  77. Replacement symbols and relative common subexpressions that have been
  78. replaced during a CSE operation. Compared to the input replacement list,
  79. the output one doesn't contain replacement symbols inside
  80. ``Derivative``'s arguments.
  81. jacobian : list of SymPy expressions
  82. The list only contains one element, which is the Jacobian matrix with
  83. elements in reduced form (replacement symbols are present).
  84. precomputed_fs: list
  85. List of sets, which store the free symbols present in each sub-expression.
  86. Useful in the substitution process.
  87. """
  88. if not isinstance(reduced_expr[0], MatrixBase):
  89. raise TypeError("``expr`` must be of matrix type")
  90. if not (reduced_expr[0].shape[0] == 1 or reduced_expr[0].shape[1] == 1):
  91. raise TypeError("``expr`` must be a row or a column matrix")
  92. if not iterable(wrt):
  93. raise TypeError("``wrt`` must be an iterable of variables")
  94. elif not isinstance(wrt, MatrixBase):
  95. wrt = Matrix(wrt)
  96. if not (wrt.shape[0] == 1 or wrt.shape[1] == 1):
  97. raise TypeError("``wrt`` must be a row or a column matrix")
  98. replacements, reduced_expr = _remove_cse_from_derivative(replacements, reduced_expr)
  99. if replacements:
  100. rep_sym, sub_expr = map(Matrix, zip(*replacements))
  101. else:
  102. rep_sym, sub_expr = Matrix([]), Matrix([])
  103. l_sub, l_wrt, l_red = len(sub_expr), len(wrt), len(reduced_expr[0])
  104. f1 = reduced_expr[0].__class__.from_dok(l_red, l_wrt,
  105. {
  106. (i, j): diff_value
  107. for i, r in enumerate(reduced_expr[0])
  108. for j, w in enumerate(wrt)
  109. if (diff_value := r.diff(w)) != 0
  110. },
  111. )
  112. if not replacements:
  113. return [], [f1], []
  114. f2 = Matrix.from_dok(l_red, l_sub,
  115. {
  116. (i, j): diff_value
  117. for i, (r, fs) in enumerate([(r, r.free_symbols) for r in reduced_expr[0]])
  118. for j, s in enumerate(rep_sym)
  119. if s in fs and (diff_value := r.diff(s)) != 0
  120. },
  121. )
  122. rep_sym_set = set(rep_sym)
  123. precomputed_fs = [s.free_symbols & rep_sym_set for s in sub_expr ]
  124. c_matrix = Matrix.from_dok(1, l_wrt,
  125. {(0, j): diff_value for j, w in enumerate(wrt)
  126. if (diff_value := sub_expr[0].diff(w)) != 0})
  127. for i in range(1, l_sub):
  128. bi_matrix = Matrix.from_dok(1, i,
  129. {(0, j): diff_value for j in range(i + 1)
  130. if rep_sym[j] in precomputed_fs[i]
  131. and (diff_value := sub_expr[i].diff(rep_sym[j])) != 0})
  132. ai_matrix = Matrix.from_dok(1, l_wrt,
  133. {(0, j): diff_value for j, w in enumerate(wrt)
  134. if (diff_value := sub_expr[i].diff(w)) != 0})
  135. if bi_matrix._rep.nnz():
  136. ci_matrix = bi_matrix.multiply(c_matrix).add(ai_matrix)
  137. c_matrix = Matrix.vstack(c_matrix, ci_matrix)
  138. else:
  139. c_matrix = Matrix.vstack(c_matrix, ai_matrix)
  140. jacobian = f2.multiply(c_matrix).add(f1)
  141. jacobian = [reduced_expr[0].__class__(jacobian)]
  142. return replacements, jacobian, precomputed_fs
  143. def _forward_jacobian_norm_in_cse_out(expr, wrt):
  144. """
  145. Function to compute the Jacobian of an input Matrix of expressions through
  146. forward accumulation. Takes a sympy Matrix of expressions (expr) as input
  147. and an iterable of variables (wrt) with respect to which to compute the
  148. Jacobian matrix. The matrix is returned in reduced form (containing
  149. replacement symbols) along with the ``replacements`` list.
  150. The function also returns a list of precomputed free symbols for each
  151. subexpression, which are useful in the substitution process.
  152. Parameters
  153. ==========
  154. expr : Matrix
  155. The vector to be differentiated.
  156. wrt : iterable
  157. The vector with respect to which to perform the differentiation.
  158. Can be a matrix or an iterable of variables.
  159. Returns
  160. =======
  161. replacements : list of (Symbol, expression) pairs
  162. Replacement symbols and relative common subexpressions that have been
  163. replaced during a CSE operation. The output replacement list doesn't
  164. contain replacement symbols inside ``Derivative``'s arguments.
  165. jacobian : list of SymPy expressions
  166. The list only contains one element, which is the Jacobian matrix with
  167. elements in reduced form (replacement symbols are present).
  168. precomputed_fs: list
  169. List of sets, which store the free symbols present in each
  170. sub-expression. Useful in the substitution process.
  171. """
  172. replacements, reduced_expr = cse(expr)
  173. replacements, jacobian, precomputed_fs = _forward_jacobian_cse(replacements, reduced_expr, wrt)
  174. return replacements, jacobian, precomputed_fs
  175. def _forward_jacobian(expr, wrt):
  176. """
  177. Function to compute the Jacobian of an input Matrix of expressions through
  178. forward accumulation. Takes a sympy Matrix of expressions (expr) as input
  179. and an iterable of variables (wrt) with respect to which to compute the
  180. Jacobian matrix.
  181. Explanation
  182. ===========
  183. Expressions often contain repeated subexpressions. Using a tree structure,
  184. these subexpressions are duplicated and differentiated multiple times,
  185. leading to inefficiency.
  186. Instead, if a data structure called a directed acyclic graph (DAG) is used
  187. then each of these repeated subexpressions will only exist a single time.
  188. This function uses a combination of representing the expression as a DAG and
  189. a forward accumulation algorithm (repeated application of the chain rule
  190. symbolically) to more efficiently calculate the Jacobian matrix of a target
  191. expression ``expr`` with respect to an expression or set of expressions
  192. ``wrt``.
  193. Note that this function is intended to improve performance when
  194. differentiating large expressions that contain many common subexpressions.
  195. For small and simple expressions it is likely less performant than using
  196. SymPy's standard differentiation functions and methods.
  197. Parameters
  198. ==========
  199. expr : Matrix
  200. The vector to be differentiated.
  201. wrt : iterable
  202. The vector with respect to which to do the differentiation.
  203. Can be a matrix or an iterable of variables.
  204. See Also
  205. ========
  206. Direct Acyclic Graph : https://en.wikipedia.org/wiki/Directed_acyclic_graph
  207. """
  208. replacements, reduced_expr = cse(expr)
  209. if replacements:
  210. rep_sym, _ = map(Matrix, zip(*replacements))
  211. else:
  212. rep_sym = Matrix([])
  213. replacements, jacobian, precomputed_fs = _forward_jacobian_cse(replacements, reduced_expr, wrt)
  214. if not replacements: return jacobian[0]
  215. sub_rep = dict(replacements)
  216. for i, ik in enumerate(precomputed_fs):
  217. sub_dict = {j: sub_rep[j] for j in ik}
  218. sub_rep[rep_sym[i]] = sub_rep[rep_sym[i]].xreplace(sub_dict)
  219. return jacobian[0].xreplace(sub_rep)