numpy_nodes.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. from sympy.core.function import Add, ArgumentIndexError, Function
  2. from sympy.core.power import Pow
  3. from sympy.core.singleton import S
  4. from sympy.core.sorting import default_sort_key
  5. from sympy.core.sympify import sympify
  6. from sympy.functions.elementary.exponential import exp, log
  7. from sympy.functions.elementary.miscellaneous import Max, Min
  8. from .ast import Token, none
  9. def _logaddexp(x1, x2, *, evaluate=True):
  10. return log(Add(exp(x1, evaluate=evaluate), exp(x2, evaluate=evaluate), evaluate=evaluate))
  11. _two = S.One*2
  12. _ln2 = log(_two)
  13. def _lb(x, *, evaluate=True):
  14. return log(x, evaluate=evaluate)/_ln2
  15. def _exp2(x, *, evaluate=True):
  16. return Pow(_two, x, evaluate=evaluate)
  17. def _logaddexp2(x1, x2, *, evaluate=True):
  18. return _lb(Add(_exp2(x1, evaluate=evaluate),
  19. _exp2(x2, evaluate=evaluate), evaluate=evaluate))
  20. class logaddexp(Function):
  21. """ Logarithm of the sum of exponentiations of the inputs.
  22. Helper class for use with e.g. numpy.logaddexp
  23. See Also
  24. ========
  25. https://numpy.org/doc/stable/reference/generated/numpy.logaddexp.html
  26. """
  27. nargs = 2
  28. def __new__(cls, *args):
  29. return Function.__new__(cls, *sorted(args, key=default_sort_key))
  30. def fdiff(self, argindex=1):
  31. """
  32. Returns the first derivative of this function.
  33. """
  34. if argindex == 1:
  35. wrt, other = self.args
  36. elif argindex == 2:
  37. other, wrt = self.args
  38. else:
  39. raise ArgumentIndexError(self, argindex)
  40. return S.One/(S.One + exp(other-wrt))
  41. def _eval_rewrite_as_log(self, x1, x2, **kwargs):
  42. return _logaddexp(x1, x2)
  43. def _eval_evalf(self, *args, **kwargs):
  44. return self.rewrite(log).evalf(*args, **kwargs)
  45. def _eval_simplify(self, *args, **kwargs):
  46. a, b = (x.simplify(**kwargs) for x in self.args)
  47. candidate = _logaddexp(a, b)
  48. if candidate != _logaddexp(a, b, evaluate=False):
  49. return candidate
  50. else:
  51. return logaddexp(a, b)
  52. class logaddexp2(Function):
  53. """ Logarithm of the sum of exponentiations of the inputs in base-2.
  54. Helper class for use with e.g. numpy.logaddexp2
  55. See Also
  56. ========
  57. https://numpy.org/doc/stable/reference/generated/numpy.logaddexp2.html
  58. """
  59. nargs = 2
  60. def __new__(cls, *args):
  61. return Function.__new__(cls, *sorted(args, key=default_sort_key))
  62. def fdiff(self, argindex=1):
  63. """
  64. Returns the first derivative of this function.
  65. """
  66. if argindex == 1:
  67. wrt, other = self.args
  68. elif argindex == 2:
  69. other, wrt = self.args
  70. else:
  71. raise ArgumentIndexError(self, argindex)
  72. return S.One/(S.One + _exp2(other-wrt))
  73. def _eval_rewrite_as_log(self, x1, x2, **kwargs):
  74. return _logaddexp2(x1, x2)
  75. def _eval_evalf(self, *args, **kwargs):
  76. return self.rewrite(log).evalf(*args, **kwargs)
  77. def _eval_simplify(self, *args, **kwargs):
  78. a, b = (x.simplify(**kwargs).factor() for x in self.args)
  79. candidate = _logaddexp2(a, b)
  80. if candidate != _logaddexp2(a, b, evaluate=False):
  81. return candidate
  82. else:
  83. return logaddexp2(a, b)
  84. class amin(Token):
  85. """ Minimum value along an axis.
  86. Helper class for use with e.g. numpy.amin
  87. See Also
  88. ========
  89. https://numpy.org/doc/stable/reference/generated/numpy.amin.html
  90. """
  91. __slots__ = _fields = ('array', 'axis')
  92. defaults = {'axis': none}
  93. _construct_axis = staticmethod(sympify)
  94. class amax(Token):
  95. """ Maximum value along an axis.
  96. Helper class for use with e.g. numpy.amax
  97. See Also
  98. ========
  99. https://numpy.org/doc/stable/reference/generated/numpy.amax.html
  100. """
  101. __slots__ = _fields = ('array', 'axis')
  102. defaults = {'axis': none}
  103. _construct_axis = staticmethod(sympify)
  104. class maximum(Function):
  105. """ Element-wise maximum of array elements.
  106. Helper class for use with e.g. numpy.maximum
  107. See Also
  108. ========
  109. https://numpy.org/doc/stable/reference/generated/numpy.maximum.html
  110. """
  111. def _eval_rewrite_as_Max(self, *args):
  112. return Max(*self.args)
  113. class minimum(Function):
  114. """ Element-wise minimum of array elements.
  115. Helper class for use with e.g. numpy.minimum
  116. See Also
  117. ========
  118. https://numpy.org/doc/stable/reference/generated/numpy.minimum.html
  119. """
  120. def _eval_rewrite_as_Min(self, *args):
  121. return Min(*self.args)