test_cse_diff.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. """Tests for the ``sympy.simplify._cse_diff.py`` module."""
  2. import pytest
  3. from sympy.core.symbol import (Symbol, symbols)
  4. from sympy.core.numbers import Integer
  5. from sympy.core.function import Function
  6. from sympy.core import Derivative
  7. from sympy.functions.elementary.exponential import exp
  8. from sympy.matrices.immutable import ImmutableDenseMatrix
  9. from sympy.physics.mechanics import dynamicsymbols
  10. from sympy.simplify._cse_diff import (_forward_jacobian,
  11. _remove_cse_from_derivative,
  12. _forward_jacobian_cse,
  13. _forward_jacobian_norm_in_cse_out)
  14. from sympy.simplify.simplify import simplify
  15. from sympy.matrices import Matrix, eye
  16. from sympy.testing.pytest import raises
  17. from sympy.functions.elementary.trigonometric import (cos, sin, tan)
  18. from sympy.simplify.trigsimp import trigsimp
  19. from sympy import cse
  20. w = Symbol('w')
  21. x = Symbol('x')
  22. y = Symbol('y')
  23. z = Symbol('z')
  24. q1, q2, q3 = dynamicsymbols('q1 q2 q3')
  25. # Define the custom functions
  26. k = Function('k')(x, y)
  27. f = Function('f')(k, z)
  28. zero = Integer(0)
  29. one = Integer(1)
  30. two = Integer(2)
  31. neg_one = Integer(-1)
  32. @pytest.mark.parametrize(
  33. 'expr, wrt',
  34. [
  35. ([zero], [x]),
  36. ([one], [x]),
  37. ([two], [x]),
  38. ([neg_one], [x]),
  39. ([x], [x]),
  40. ([y], [x]),
  41. ([x + y], [x]),
  42. ([x*y], [x]),
  43. ([x**2], [x]),
  44. ([x**y], [x]),
  45. ([exp(x)], [x]),
  46. ([sin(x)], [x]),
  47. ([tan(x)], [x]),
  48. ([zero, one, x, y, x*y, x + y], [x, y]),
  49. ([((x/y) + sin(x/y) - exp(y))*((x/y) - exp(y))], [x, y]),
  50. ([w*tan(y*z)/(x - tan(y*z)), w*x*tan(y*z)/(x - tan(y*z))], [w, x, y, z]),
  51. ([q1**2 + q2, q2**2 + q3, q3**2 + q1], [q1, q2, q3]),
  52. ([f + Derivative(f, x) + k + 2*x], [x])
  53. ]
  54. )
  55. def test_forward_jacobian(expr, wrt):
  56. expr = ImmutableDenseMatrix([expr]).T
  57. wrt = ImmutableDenseMatrix([wrt]).T
  58. jacobian = _forward_jacobian(expr, wrt)
  59. zeros = ImmutableDenseMatrix.zeros(*jacobian.shape)
  60. assert simplify(jacobian - expr.jacobian(wrt)) == zeros
  61. def test_process_cse():
  62. x, y, z = symbols('x y z')
  63. f = Function('f')
  64. k = Function('k')
  65. expr = Matrix([f(k(x,y), z) + Derivative(f(k(x,y), z), x) + k(x,y) + 2*x])
  66. repl, reduced = cse(expr)
  67. p_repl, p_reduced = _remove_cse_from_derivative(repl, reduced)
  68. x0 = symbols('x0')
  69. x1 = symbols('x1')
  70. expected_output = (
  71. [(x0, k(x, y)), (x1, f(x0, z))],
  72. [Matrix([2 * x + x0 + x1 + Derivative(f(k(x, y), z), x)])]
  73. )
  74. assert p_repl == expected_output[0], f"Expected {expected_output[0]}, but got {p_repl}"
  75. assert p_reduced == expected_output[1], f"Expected {expected_output[1]}, but got {p_reduced}"
  76. def test_io_matrix_type():
  77. x, y, z = symbols('x y z')
  78. expr = ImmutableDenseMatrix([
  79. x * y + y * z + x * y * z,
  80. x ** 2 + y ** 2 + z ** 2,
  81. x * y + x * z + y * z
  82. ])
  83. wrt = ImmutableDenseMatrix([x, y, z])
  84. replacements, reduced_expr = cse(expr)
  85. # Test _forward_jacobian_core
  86. replacements_core, jacobian_core, precomputed_fs_core = _forward_jacobian_cse(replacements, reduced_expr, wrt)
  87. assert isinstance(jacobian_core[0], type(reduced_expr[0])), "Jacobian should be a Matrix of the same type as the input"
  88. # Test _forward_jacobian_norm_in_dag_out
  89. replacements_norm, jacobian_norm, precomputed_fs_norm = _forward_jacobian_norm_in_cse_out(
  90. expr, wrt)
  91. assert isinstance(jacobian_norm[0], type(reduced_expr[0])), "Jacobian should be a Matrix of the same type as the input"
  92. # Test _forward_jacobian
  93. jacobian = _forward_jacobian(expr, wrt)
  94. assert isinstance(jacobian, type(expr)), "Jacobian should be a Matrix of the same type as the input"
  95. def test_forward_jacobian_input_output():
  96. x, y, z = symbols('x y z')
  97. expr = Matrix([
  98. x * y + y * z + x * y * z,
  99. x ** 2 + y ** 2 + z ** 2,
  100. x * y + x * z + y * z
  101. ])
  102. wrt = Matrix([x, y, z])
  103. replacements, reduced_expr = cse(expr)
  104. # Test _forward_jacobian_core
  105. replacements_core, jacobian_core, precomputed_fs_core = _forward_jacobian_cse(replacements, reduced_expr, wrt)
  106. assert isinstance(replacements_core, type(replacements)), "Replacements should be a list"
  107. assert isinstance(jacobian_core, type(reduced_expr)), "Jacobian should be a list"
  108. assert isinstance(precomputed_fs_core, list), "Precomputed free symbols should be a list"
  109. assert len(replacements_core) == len(replacements), "Length of replacements does not match"
  110. assert len(jacobian_core) == 1, "Jacobian should have one element"
  111. assert len(precomputed_fs_core) == len(replacements), "Length of precomputed free symbols does not match"
  112. # Test _forward_jacobian_norm_in_dag_out
  113. replacements_norm, jacobian_norm, precomputed_fs_norm = _forward_jacobian_norm_in_cse_out(expr, wrt)
  114. assert isinstance(replacements_norm, type(replacements)), "Replacements should be a list"
  115. assert isinstance(jacobian_norm, type(reduced_expr)), "Jacobian should be a list"
  116. assert isinstance(precomputed_fs_norm, list), "Precomputed free symbols should be a list"
  117. assert len(replacements_norm) == len(replacements), "Length of replacements does not match"
  118. assert len(jacobian_norm) == 1, "Jacobian should have one element"
  119. assert len(precomputed_fs_norm) == len(replacements), "Length of precomputed free symbols does not match"
  120. def test_jacobian_hessian():
  121. L = Matrix(1, 2, [x**2*y, 2*y**2 + x*y])
  122. syms = [x, y]
  123. assert _forward_jacobian(L, syms) == Matrix([[2*x*y, x**2], [y, 4*y + x]])
  124. L = Matrix(1, 2, [x, x**2*y**3])
  125. assert _forward_jacobian(L, syms) == Matrix([[1, 0], [2*x*y**3, x**2*3*y**2]])
  126. def test_jacobian_metrics():
  127. rho, phi = symbols("rho,phi")
  128. X = Matrix([rho * cos(phi), rho * sin(phi)])
  129. Y = Matrix([rho, phi])
  130. J = _forward_jacobian(X, Y)
  131. assert J == X.jacobian(Y.T)
  132. assert J == (X.T).jacobian(Y)
  133. assert J == (X.T).jacobian(Y.T)
  134. g = J.T * eye(J.shape[0]) * J
  135. g = g.applyfunc(trigsimp)
  136. assert g == Matrix([[1, 0], [0, rho ** 2]])
  137. def test_jacobian2():
  138. rho, phi = symbols("rho,phi")
  139. X = Matrix([rho * cos(phi), rho * sin(phi), rho ** 2])
  140. Y = Matrix([rho, phi])
  141. J = Matrix([
  142. [cos(phi), -rho * sin(phi)],
  143. [sin(phi), rho * cos(phi)],
  144. [2 * rho, 0],
  145. ])
  146. assert _forward_jacobian(X, Y) == J
  147. def test_issue_4564():
  148. X = Matrix([exp(x + y + z), exp(x + y + z), exp(x + y + z)])
  149. Y = Matrix([x, y, z])
  150. for i in range(1, 3):
  151. for j in range(1, 3):
  152. X_slice = X[:i, :]
  153. Y_slice = Y[:j, :]
  154. J = _forward_jacobian(X_slice, Y_slice)
  155. assert J.rows == i
  156. assert J.cols == j
  157. for k in range(j):
  158. assert J[:, k] == X_slice
  159. def test_nonvectorJacobian():
  160. X = Matrix([[exp(x + y + z), exp(x + y + z)],
  161. [exp(x + y + z), exp(x + y + z)]])
  162. raises(TypeError, lambda: _forward_jacobian(X, Matrix([x, y, z])))
  163. X = X[0, :]
  164. Y = Matrix([[x, y], [x, z]])
  165. raises(TypeError, lambda: _forward_jacobian(X, Y))
  166. raises(TypeError, lambda: _forward_jacobian(X, Matrix([[x, y], [x, z]])))