test_matrix_nodes.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. from sympy.core.symbol import symbols
  2. from sympy.core.function import Function
  3. from sympy.matrices.dense import Matrix
  4. from sympy.matrices.dense import zeros
  5. from sympy.simplify.simplify import simplify
  6. from sympy.codegen.matrix_nodes import MatrixSolve
  7. from sympy.utilities.lambdify import lambdify
  8. from sympy.printing.numpy import NumPyPrinter
  9. from sympy.testing.pytest import skip
  10. from sympy.external import import_module
  11. def test_matrix_solve_issue_24862():
  12. A = Matrix(3, 3, symbols('a:9'))
  13. b = Matrix(3, 1, symbols('b:3'))
  14. hash(MatrixSolve(A, b))
  15. def test_matrix_solve_derivative_exact():
  16. q = symbols('q')
  17. a11, a12, a21, a22, b1, b2 = (
  18. f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
  19. A = Matrix([[a11, a12], [a21, a22]])
  20. b = Matrix([b1, b2])
  21. x_lu = A.LUsolve(b)
  22. dxdq_lu = A.LUsolve(b.diff(q) - A.diff(q) * A.LUsolve(b))
  23. assert simplify(x_lu.diff(q) - dxdq_lu) == zeros(2, 1)
  24. # dxdq_ms is the MatrixSolve equivalent of dxdq_lu
  25. dxdq_ms = MatrixSolve(A, b.diff(q) - A.diff(q) * MatrixSolve(A, b))
  26. assert MatrixSolve(A, b).diff(q) == dxdq_ms
  27. def test_matrix_solve_derivative_numpy():
  28. np = import_module('numpy')
  29. if not np:
  30. skip("numpy not installed.")
  31. q = symbols('q')
  32. a11, a12, a21, a22, b1, b2 = (
  33. f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
  34. A = Matrix([[a11, a12], [a21, a22]])
  35. b = Matrix([b1, b2])
  36. dx_lu = A.LUsolve(b).diff(q)
  37. subs = {a11.diff(q): 0.2, a12.diff(q): 0.3, a21.diff(q): 0.1,
  38. a22.diff(q): 0.5, b1.diff(q): 0.4, b2.diff(q): 0.9,
  39. a11: 1.3, a12: 0.5, a21: 1.2, a22: 4, b1: 6.2, b2: 3.5}
  40. p, p_vals = zip(*subs.items())
  41. dx_sm = MatrixSolve(A, b).diff(q)
  42. np.testing.assert_allclose(
  43. lambdify(p, dx_sm, printer=NumPyPrinter)(*p_vals),
  44. lambdify(p, dx_lu, printer=NumPyPrinter)(*p_vals))