| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291 |
- """Module for differentiation using CSE."""
- from sympy import cse, Matrix, Derivative, MatrixBase
- from sympy.utilities.iterables import iterable
def _remove_cse_from_derivative(replacements, reduced_expressions):
    """
    Postprocess the output of a common subexpression elimination (CSE).

    Any CSE replacement symbol appearing inside a ``Derivative`` term is
    substituted back (repeatedly, until no replacement symbol is left in
    that ``Derivative``). Derivative terms are therefore expressed directly
    in the original variables. This is necessary to ensure that the forward
    Jacobian function correctly handles derivative terms.

    Parameters
    ==========

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and the corresponding common subexpressions that
        have been replaced during a CSE operation.
    reduced_expressions : list of SymPy matrices
        The reduced expressions with all the replacements from the
        replacements list above. Each element is iterated element-wise and
        rebuilt, so it must be a matrix.

    Returns
    =======

    processed_replacements : list of (Symbol, expression) pairs
        Processed replacement list, in the same format and order as the
        ``replacements`` input list.
    processed_reduced : list of SymPy matrices
        Processed reduced list; each matrix keeps the class and the shape of
        the corresponding input matrix.
    """
    repl_dict = dict(replacements)

    def replace_all(node):
        # A replacement may itself refer to earlier replacement symbols, so
        # keep substituting until no replacement symbol is left.
        result = node
        while True:
            pending = {sym: repl_dict[sym] for sym in result.free_symbols
                       if sym in repl_dict}
            if not pending:
                return result
            result = result.xreplace(pending)

    def traverse(node):
        if isinstance(node, Derivative):
            return replace_all(node)
        if not node.args:
            return node
        new_args = [traverse(arg) for arg in node.args]
        # Only rebuild nodes whose arguments actually changed: rebuilding an
        # untouched node is wasted work and re-evaluates it.
        if all(new is old for new, old in zip(new_args, node.args)):
            return node
        return node.func(*new_args)

    processed_replacements = [
        (rep_sym, traverse(sub_exp)) for rep_sym, sub_exp in replacements
    ]
    # Pass the shape explicitly: giving the constructor only the flat list of
    # elements would turn e.g. a row vector into a column vector.
    processed_reduced = [
        red_exp.__class__(red_exp.rows, red_exp.cols,
                          [traverse(elem) for elem in red_exp])
        for red_exp in reduced_expressions
    ]
    return processed_replacements, processed_reduced
def _forward_jacobian_cse(replacements, reduced_expr, wrt):
    """
    Core function to compute the Jacobian of an input Matrix of expressions
    through forward accumulation.

    Takes directly the output of a CSE operation (``replacements`` and
    ``reduced_expr``) and an iterable of variables (``wrt``) with respect to
    which to differentiate the reduced expression. It returns the reduced
    Jacobian matrix and the ``replacements`` list. It also returns a list of
    precomputed free symbols for each subexpression, which are useful in the
    substitution process.

    The Jacobian is assembled with the chain rule as ``J = F2 * C + F1``:

    * ``F1[i, j]`` is the partial derivative of reduced expression ``i``
      with respect to ``wrt[j]``.
    * ``F2[i, k]`` is the partial derivative of reduced expression ``i``
      with respect to replacement symbol ``k``.
    * Row ``k`` of ``C`` is the total derivative of replacement symbol
      ``k`` with respect to ``wrt``. It is built row by row as the direct
      partials of subexpression ``k`` plus the chain-rule contribution of
      the earlier replacement symbols it contains.

    Parameters
    ==========

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and the corresponding common subexpressions that
        have been replaced during a CSE operation. Later subexpressions may
        only refer to earlier replacement symbols (the order produced by
        ``cse``).
    reduced_expr : list of SymPy expressions
        The reduced expressions with all the replacements from the
        replacements list above. Its first element must be a row or column
        matrix.
    wrt : iterable
        Iterable of expressions with respect to which to compute the
        Jacobian matrix.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and the corresponding common subexpressions that
        have been replaced during a CSE operation. Compared to the input
        replacement list, the output one doesn't contain replacement symbols
        inside ``Derivative``'s arguments. Empty if there were no
        replacements.
    jacobian : list of SymPy expressions
        The list only contains one element, which is the Jacobian matrix with
        elements in reduced form (replacement symbols are present).
    precomputed_fs: list
        List of sets, which store the replacement symbols present in each
        sub-expression. Useful in the substitution process. Empty if there
        were no replacements.

    Raises
    ======

    TypeError
        If ``reduced_expr[0]`` is not a row/column matrix, or if ``wrt`` is
        not an iterable forming a row/column matrix.
    """
    if not isinstance(reduced_expr[0], MatrixBase):
        raise TypeError("``expr`` must be of matrix type")

    if not (reduced_expr[0].shape[0] == 1 or reduced_expr[0].shape[1] == 1):
        raise TypeError("``expr`` must be a row or a column matrix")

    if not iterable(wrt):
        raise TypeError("``wrt`` must be an iterable of variables")
    elif not isinstance(wrt, MatrixBase):
        wrt = Matrix(wrt)

    if not (wrt.shape[0] == 1 or wrt.shape[1] == 1):
        raise TypeError("``wrt`` must be a row or a column matrix")

    replacements, reduced_expr = _remove_cse_from_derivative(
        replacements, reduced_expr)
    if replacements:
        rep_sym, sub_expr = map(Matrix, zip(*replacements))
    else:
        rep_sym, sub_expr = Matrix([]), Matrix([])

    l_sub, l_wrt, l_red = len(sub_expr), len(wrt), len(reduced_expr[0])

    # F1: explicit dependence of the reduced expressions on ``wrt``.
    f1 = reduced_expr[0].__class__.from_dok(l_red, l_wrt,
        {
            (i, j): diff_value
            for i, r in enumerate(reduced_expr[0])
            for j, w in enumerate(wrt)
            if (diff_value := r.diff(w)) != 0
        },
    )

    if not replacements:
        return [], [f1], []

    rep_sym_set = set(rep_sym)
    # Column index of each replacement symbol. It lets the builders below
    # visit only the replacement symbols an expression actually contains,
    # instead of scanning all of them for every expression.
    sym_index = {sym: k for k, sym in enumerate(rep_sym)}

    # F2: dependence of the reduced expressions on the replacement symbols.
    f2 = Matrix.from_dok(l_red, l_sub,
        {
            (i, sym_index[s]): diff_value
            for i, r in enumerate(reduced_expr[0])
            for s in r.free_symbols & rep_sym_set
            if (diff_value := r.diff(s)) != 0
        },
    )

    precomputed_fs = [s.free_symbols & rep_sym_set for s in sub_expr]

    # The first subexpression cannot contain replacement symbols, so its
    # total derivative is just its direct partials.
    c_matrix = Matrix.from_dok(1, l_wrt,
        {(0, j): diff_value for j, w in enumerate(wrt)
         if (diff_value := sub_expr[0].diff(w)) != 0})

    for i in range(1, l_sub):
        # bi[0, j] = d(sub_expr[i]) / d(rep_sym[j]). The matrix has exactly
        # ``i`` columns: subexpression ``i`` can only refer to earlier
        # replacement symbols (j < i).
        bi_entries = {}
        for s in precomputed_fs[i]:
            j = sym_index[s]
            if j < i and (diff_value := sub_expr[i].diff(s)) != 0:
                bi_entries[(0, j)] = diff_value
        bi_matrix = Matrix.from_dok(1, i, bi_entries)

        # ai[0, j] = direct partial of sub_expr[i] with respect to wrt[j].
        ai_matrix = Matrix.from_dok(1, l_wrt,
            {(0, j): diff_value for j, w in enumerate(wrt)
             if (diff_value := sub_expr[i].diff(w)) != 0})

        # Skip the matrix product when no earlier replacement contributes.
        if bi_matrix._rep.nnz():
            ci_matrix = bi_matrix.multiply(c_matrix).add(ai_matrix)
        else:
            ci_matrix = ai_matrix
        c_matrix = Matrix.vstack(c_matrix, ci_matrix)

    jacobian = f2.multiply(c_matrix).add(f1)
    jacobian = [reduced_expr[0].__class__(jacobian)]

    return replacements, jacobian, precomputed_fs
def _forward_jacobian_norm_in_cse_out(expr, wrt):
    """
    Compute the Jacobian of ``expr`` with respect to ``wrt``, in CSE form.

    ``expr`` is first reduced with ``cse``. The forward-accumulation Jacobian
    is then computed on the reduced form, so the returned matrix still
    contains replacement symbols. The ``replacements`` list and the
    per-subexpression replacement-symbol sets needed to expand them are
    returned alongside it.

    Parameters
    ==========

    expr : Matrix
        The vector to be differentiated.
    wrt : iterable
        The vector with respect to which to perform the differentiation.
        Can be a matrix or an iterable of variables.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and the corresponding common subexpressions that
        have been replaced during a CSE operation. The output replacement
        list doesn't contain replacement symbols inside ``Derivative``'s
        arguments.
    jacobian : list of SymPy expressions
        The list only contains one element, which is the Jacobian matrix with
        elements in reduced form (replacement symbols are present).
    precomputed_fs: list
        List of sets, which store the free symbols present in each
        sub-expression. Useful in the substitution process.
    """
    cse_replacements, cse_reduced = cse(expr)
    return _forward_jacobian_cse(cse_replacements, cse_reduced, wrt)
def _forward_jacobian(expr, wrt):
    """
    Function to compute the Jacobian of an input Matrix of expressions through
    forward accumulation. Takes a sympy Matrix of expressions (expr) as input
    and an iterable of variables (wrt) with respect to which to compute the
    Jacobian matrix.

    Explanation
    ===========

    Expressions often contain repeated subexpressions. Using a tree structure,
    these subexpressions are duplicated and differentiated multiple times,
    leading to inefficiency.

    Instead, if a data structure called a directed acyclic graph (DAG) is used
    then each of these repeated subexpressions will only exist a single time.
    This function uses a combination of representing the expression as a DAG and
    a forward accumulation algorithm (repeated application of the chain rule
    symbolically) to more efficiently calculate the Jacobian matrix of a target
    expression ``expr`` with respect to an expression or set of expressions
    ``wrt``.

    Note that this function is intended to improve performance when
    differentiating large expressions that contain many common subexpressions.
    For small and simple expressions it is likely less performant than using
    SymPy's standard differentiation functions and methods.

    Parameters
    ==========

    expr : Matrix
        The vector to be differentiated.
    wrt : iterable
        The vector with respect to which to do the differentiation.
        Can be a matrix or an iterable of variables.

    Returns
    =======

    Matrix
        The Jacobian matrix, with one row per element of ``expr`` and one
        column per element of ``wrt``. All CSE replacement symbols have
        been substituted back.

    Raises
    ======

    TypeError
        If ``expr`` is not a row/column matrix, or if ``wrt`` is not an
        iterable forming a row/column matrix.

    See Also
    ========

    Direct Acyclic Graph : https://en.wikipedia.org/wiki/Directed_acyclic_graph
    """
    replacements, jacobian, precomputed_fs = (
        _forward_jacobian_norm_in_cse_out(expr, wrt))
    if not replacements:
        return jacobian[0]

    # Expand every replacement symbol into the original variables. Each
    # subexpression only refers to earlier replacement symbols, so processing
    # in order means the values looked up here are already fully expanded.
    expanded = dict(replacements)
    for (rep_sym, _), needed in zip(replacements, precomputed_fs):
        expanded[rep_sym] = expanded[rep_sym].xreplace(
            {sym: expanded[sym] for sym in needed})

    return jacobian[0].xreplace(expanded)
|