test_interface.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  # This test file tests the SymPy function interface that people use to
  # create their own new functions. That interface should be as easy to use
  # as possible.
  #
  # We test that it works with both Function and DefinedFunction. New code
  # should use DefinedFunction because it has better type inference. Old code
  # still using Function should continue to work, though.
  7. from sympy.core.function import Function, DefinedFunction
  8. from sympy.core.sympify import sympify
  9. from sympy.functions.elementary.hyperbolic import tanh
  10. from sympy.functions.elementary.trigonometric import (cos, sin)
  11. from sympy.series.limits import limit
  12. from sympy.abc import x
  def test_function_series1():
      """Check a user-defined "sin" built only from ``eval`` and ``fdiff``.

      For each supported base class, the new function's Taylor series up to
      order 10 must equal that of SymPy's builtin ``sin``, and the limit of
      ``my_function(x)/x`` as ``x`` goes to 0 must be 1.
      """
      for base in (Function, DefinedFunction):
          class my_function(base):
              def fdiff(self, argindex=1):
                  # d/dx sin(u) == cos(u). argindex is ignored because eval
                  # below accepts exactly one argument.
                  inner = self.args[0]
                  return cos(inner)

              @classmethod
              def eval(cls, arg):
                  # Only the value at the origin is pinned down (sin(0) == 0);
                  # every other argument returns None.
                  if sympify(arg) == 0:
                      return sympify(0)
                  return None

          # The Taylor expansion must agree with SymPy's builtin sin.
          got = my_function(x).series(x, 0, 10)
          want = sin(x).series(x, 0, 10)
          assert got == want
          assert limit(my_function(x)/x, x, 0) == 1
  def test_function_series2():
      """Check a user-defined "cos" built only from ``eval`` and ``fdiff``.

      For each supported base class, the new function's Taylor series up to
      order 10 must equal that of SymPy's builtin ``cos``.
      """
      for base in (Function, DefinedFunction):
          class my_function2(base):
              def fdiff(self, argindex=1):
                  # d/dx cos(u) == -sin(u); single-argument function, so
                  # argindex is ignored.
                  inner = self.args[0]
                  return -sin(inner)

              @classmethod
              def eval(cls, arg):
                  # Only the value at the origin is pinned down (cos(0) == 1);
                  # every other argument returns None.
                  if sympify(arg) == 0:
                      return sympify(1)
                  return None

          # The Taylor expansion must agree with SymPy's builtin cos.
          got = my_function2(x).series(x, 0, 10)
          want = cos(x).series(x, 0, 10)
          assert got == want
  def test_function_series3():
      """Check a user-defined "tanh" whose derivative refers back to itself.

      Two things are exercised:

      * subclassing the function interface stays simple to use, and
      * the general series-expansion algorithm still works when the
        derivative is expressed in terms of the function being defined,
        mirroring ``tanh(x).diff(x) == 1 - tanh(x)**2``.
      """
      for base in (Function, DefinedFunction):
          class mytanh(base):
              def fdiff(self, argindex=1):
                  # Recursive derivative 1 - mytanh(u)**2. self.func is the
                  # class of this instance, i.e. the mytanh defined in the
                  # current loop iteration.
                  return 1 - self.func(self.args[0])**2

              @classmethod
              def eval(cls, arg):
                  # Only the value at the origin is pinned down (tanh(0) == 0);
                  # every other argument returns None.
                  if sympify(arg) == 0:
                      return sympify(0)
                  return None

          # Expansion up to order 6 must match SymPy's builtin tanh.
          reference = tanh(x).series(x, 0, 6)
          candidate = mytanh(x).series(x, 0, 6)
          assert reference == candidate