_logsumexp.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. import numpy as np
  2. from scipy._lib._array_api import (
  3. array_namespace,
  4. xp_capabilities,
  5. xp_device,
  6. xp_size,
  7. xp_promote,
  8. xp_float_to_complex,
  9. )
  10. from scipy._lib import array_api_extra as xpx
  11. __all__ = ["logsumexp", "softmax", "log_softmax"]
  12. @xp_capabilities()
  13. def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  14. """Compute the log of the sum of exponentials of input elements.
  15. Parameters
  16. ----------
  17. a : array_like
  18. Input array.
  19. axis : None or int or tuple of ints, optional
  20. Axis or axes over which the sum is taken. By default `axis` is None,
  21. and all elements are summed.
  22. .. versionadded:: 0.11.0
  23. b : array-like, optional
  24. Scaling factor for exp(`a`) must be of the same shape as `a` or
  25. broadcastable to `a`. These values may be negative in order to
  26. implement subtraction.
  27. .. versionadded:: 0.12.0
  28. keepdims : bool, optional
  29. If this is set to True, the axes which are reduced are left in the
  30. result as dimensions with size one. With this option, the result
  31. will broadcast correctly against the original array.
  32. .. versionadded:: 0.15.0
  33. return_sign : bool, optional
  34. If this is set to True, the result will be a pair containing sign
  35. information; if False, results that are negative will be returned
  36. as NaN. Default is False (no sign information).
  37. .. versionadded:: 0.16.0
  38. Returns
  39. -------
  40. res : ndarray
  41. The result, ``np.log(np.sum(np.exp(a)))`` calculated in a numerically
  42. more stable way. If `b` is given then ``np.log(np.sum(b*np.exp(a)))``
  43. is returned. If ``return_sign`` is True, ``res`` contains the log of
  44. the absolute value of the argument.
  45. sgn : ndarray
  46. If ``return_sign`` is True, this will be an array of floating-point
  47. numbers matching res containing +1, 0, -1 (for real-valued inputs)
  48. or a complex phase (for complex inputs). This gives the sign of the
  49. argument of the logarithm in ``res``.
  50. If ``return_sign`` is False, only one result is returned.
  51. See Also
  52. --------
  53. :data:`numpy.logaddexp`
  54. :data:`numpy.logaddexp2`
  55. Notes
  56. -----
  57. NumPy has a logaddexp function which is very similar to `logsumexp`, but
  58. only handles two arguments. `logaddexp.reduce` is similar to this
  59. function, but may be less stable.
  60. The logarithm is a multivalued function: for each :math:`x` there is an
  61. infinite number of :math:`z` such that :math:`exp(z) = x`. The convention
  62. is to return the :math:`z` whose imaginary part lies in :math:`(-pi, pi]`.
  63. Examples
  64. --------
  65. >>> import numpy as np
  66. >>> from scipy.special import logsumexp
  67. >>> a = np.arange(10)
  68. >>> logsumexp(a)
  69. 9.4586297444267107
  70. >>> np.log(np.sum(np.exp(a)))
  71. 9.4586297444267107
  72. With weights
  73. >>> a = np.arange(10)
  74. >>> b = np.arange(10, 0, -1)
  75. >>> logsumexp(a, b=b)
  76. 9.9170178533034665
  77. >>> np.log(np.sum(b*np.exp(a)))
  78. 9.9170178533034647
  79. Returning a sign flag
  80. >>> logsumexp([1,2],b=[1,-1],return_sign=True)
  81. (1.5413248546129181, -1.0)
  82. Notice that `logsumexp` does not directly support masked arrays. To use it
  83. on a masked array, convert the mask into zero weights:
  84. >>> a = np.ma.array([np.log(2), 2, np.log(3)],
  85. ... mask=[False, True, False])
  86. >>> b = (~a.mask).astype(int)
  87. >>> logsumexp(a.data, b=b), np.log(5)
  88. 1.6094379124341005, 1.6094379124341005
  89. """
  90. xp = array_namespace(a, b)
  91. a, b = xp_promote(a, b, broadcast=True, force_floating=True, xp=xp)
  92. a = xpx.atleast_nd(a, ndim=1, xp=xp)
  93. b = xpx.atleast_nd(b, ndim=1, xp=xp) if b is not None else b
  94. axis = tuple(range(a.ndim)) if axis is None else axis
  95. if xp_size(a) != 0:
  96. with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
  97. # Where result is infinite, we use the direct logsumexp calculation to
  98. # delegate edge case handling to the behavior of `xp.log` and `xp.exp`,
  99. # which should follow the C99 standard for complex values.
  100. b_exp_a = xp.exp(a) if b is None else b * xp.exp(a)
  101. sum_ = xp.sum(b_exp_a, axis=axis, keepdims=True)
  102. sgn_inf = xp.sign(sum_) if return_sign else None
  103. sum_ = xp.abs(sum_) if return_sign else sum_
  104. out_inf = xp.log(sum_)
  105. with np.errstate(divide='ignore', invalid='ignore'): # log of zero is OK
  106. out, sgn = _logsumexp(a, b, axis=axis, return_sign=return_sign, xp=xp)
  107. # Replace infinite results. This probably could be done with an
  108. # `apply_where`-like strategy to avoid redundant calculation, but currently
  109. # `apply_where` itself is only for elementwise functions.
  110. out_finite = xp.isfinite(out)
  111. out = xp.where(out_finite, out, out_inf)
  112. sgn = xp.where(out_finite, sgn, sgn_inf) if return_sign else sgn
  113. else:
  114. shape = np.asarray(a.shape) # NumPy is convenient for shape manipulation
  115. shape[axis] = 1
  116. out = xp.full(tuple(shape), -xp.inf, dtype=a.dtype, device=xp_device(a))
  117. sgn = xp.sign(out)
  118. if xp.isdtype(out.dtype, 'complex floating'):
  119. if return_sign:
  120. real = xp.real(sgn)
  121. imag = xp_float_to_complex(_wrap_radians(xp.imag(sgn), xp=xp), xp=xp)
  122. sgn = real + imag*1j
  123. else:
  124. real = xp.real(out)
  125. imag = xp_float_to_complex(_wrap_radians(xp.imag(out), xp=xp), xp=xp)
  126. out = real + imag*1j
  127. # Deal with shape details - reducing dimensions and convert 0-D to scalar for NumPy
  128. out = xp.squeeze(out, axis=axis) if not keepdims else out
  129. sgn = xp.squeeze(sgn, axis=axis) if (sgn is not None and not keepdims) else sgn
  130. out = out[()] if out.ndim == 0 else out
  131. sgn = sgn[()] if (sgn is not None and sgn.ndim == 0) else sgn
  132. return (out, sgn) if return_sign else out
  133. def _wrap_radians(x, *, xp):
  134. # Wrap radians to (-pi, pi] interval
  135. wrapped = -((-x + xp.pi) % (2 * xp.pi) - xp.pi)
  136. # preserve relative precision
  137. no_wrap = xp.abs(x) < xp.pi
  138. return xp.where(no_wrap, x, wrapped)
  139. def _elements_and_indices_with_max_real(a, *, axis=-1, xp):
  140. # This is an array-API compatible `max` function that works something
  141. # like `np.max` for complex input. The important part is that it finds
  142. # the element with maximum real part. When there are multiple complex values
  143. # with this real part, it doesn't matter which we choose.
  144. # We could use `argmax` on real component, but array API doesn't yet have
  145. # `take_along_axis`, and even if it did, we would have problems with axis tuples.
  146. # Feel free to rewrite! It's ugly, but it's not the purpose of the PR, and
  147. # it gets the job done.
  148. if xp.isdtype(a.dtype, "complex floating"):
  149. # select all elements with max real part.
  150. real_a = xp.real(a)
  151. max_ = xp.max(real_a, axis=axis, keepdims=True)
  152. mask = real_a == max_
  153. # Of those, choose one arbitrarily. This is a reasonably
  154. # simple, array-API compatible way of doing so that doesn't
  155. # have a problem with `axis` being a tuple or None.
  156. i = xp.reshape(xp.arange(xp_size(a), device=xp_device(a)), a.shape)
  157. i = xpx.at(i, ~mask).set(-1)
  158. max_i = xp.max(i, axis=axis, keepdims=True)
  159. mask = i == max_i
  160. a = xp.where(mask, a, 0.)
  161. max_ = xp.sum(a, axis=axis, dtype=a.dtype, keepdims=True)
  162. else:
  163. max_ = xp.max(a, axis=axis, keepdims=True)
  164. mask = a == max_
  165. return max_, mask
  166. def _logsumexp(a, b, *, axis, return_sign, xp):
  167. # This has been around for about a decade, so let's consider it a feature:
  168. # Even if element of `a` is infinite or NaN, it adds nothing to the sum if
  169. # the corresponding weight is zero.
  170. if b is not None:
  171. a = xpx.at(a, b == 0).set(-xp.inf, copy=True)
  172. # Find element with maximum real part, since this is what affects the magnitude
  173. # of the exponential. Possible enhancement: include log of `b` magnitude in `a`.
  174. a_max, i_max = _elements_and_indices_with_max_real(a, axis=axis, xp=xp)
  175. # for precision, these terms are separated out of the main sum.
  176. a = xpx.at(a, i_max).set(-xp.inf, copy=True if b is None else None)
  177. i_max_dt = xp.astype(i_max, a.dtype)
  178. # This is an inefficient way of getting `m` because it is the sum of a sparse
  179. # array; however, this is the simplest way I can think of to get the right shape.
  180. b_i_max = i_max_dt if b is None else b * i_max_dt
  181. m = xp.sum(b_i_max, axis=axis, keepdims=True, dtype=a.dtype)
  182. # Shift, exponentiate, scale, and sum
  183. exp = b * xp.exp(a - a_max) if b is not None else xp.exp(a - a_max)
  184. s = xp.sum(exp, axis=axis, keepdims=True, dtype=exp.dtype)
  185. s = xp.where(s == 0, s, s/m)
  186. # Separate sign/magnitude information
  187. # Originally, this was only performed if `return_sign=True`.
  188. # However, this is also needed if any elements of `m < 0` or `s < -1`.
  189. # An improvement would be to perform the calculations only on these entries.
  190. sgn = xp.sign(s + 1) * xp.sign(m)
  191. if xp.isdtype(s.dtype, "real floating"):
  192. # The log functions need positive arguments
  193. s = xp.where(s < -1, -s - 2, s)
  194. m = xp.abs(m)
  195. else:
  196. # `a_max` can have a sign component for complex input
  197. sgn = sgn * xp.exp(xp.imag(a_max) * 1.0j)
  198. # Take log and undo shift
  199. out = xp.log1p(s) + xp.log(m) + a_max
  200. if return_sign:
  201. out = xp.real(out)
  202. elif xp.isdtype(out.dtype, 'real floating'):
  203. out = xpx.at(out)[sgn < 0].set(xp.nan)
  204. return out, sgn
  205. @xp_capabilities()
  206. def softmax(x, axis=None):
  207. r"""Compute the softmax function.
  208. The softmax function transforms each element of a collection by
  209. computing the exponential of each element divided by the sum of the
  210. exponentials of all the elements. That is, if `x` is a one-dimensional
  211. numpy array::
  212. softmax(x) = np.exp(x)/sum(np.exp(x))
  213. Parameters
  214. ----------
  215. x : array_like
  216. Input array.
  217. axis : int or tuple of ints, optional
  218. Axis to compute values along. Default is None and softmax will be
  219. computed over the entire array `x`.
  220. Returns
  221. -------
  222. s : ndarray
  223. An array the same shape as `x`. The result will sum to 1 along the
  224. specified axis.
  225. Notes
  226. -----
  227. The formula for the softmax function :math:`\sigma(x)` for a vector
  228. :math:`x = \{x_0, x_1, ..., x_{n-1}\}` is
  229. .. math:: \sigma(x)_j = \frac{e^{x_j}}{\sum_k e^{x_k}}
  230. The `softmax` function is the gradient of `logsumexp`.
  231. The implementation uses shifting to avoid overflow. See [1]_ for more
  232. details.
  233. .. versionadded:: 1.2.0
  234. References
  235. ----------
  236. .. [1] P. Blanchard, D.J. Higham, N.J. Higham, "Accurately computing the
  237. log-sum-exp and softmax functions", IMA Journal of Numerical Analysis,
  238. Vol.41(4), :doi:`10.1093/imanum/draa038`.
  239. Examples
  240. --------
  241. >>> import numpy as np
  242. >>> from scipy.special import softmax
  243. >>> np.set_printoptions(precision=5)
  244. >>> x = np.array([[1, 0.5, 0.2, 3],
  245. ... [1, -1, 7, 3],
  246. ... [2, 12, 13, 3]])
  247. ...
  248. Compute the softmax transformation over the entire array.
  249. >>> m = softmax(x)
  250. >>> m
  251. array([[ 4.48309e-06, 2.71913e-06, 2.01438e-06, 3.31258e-05],
  252. [ 4.48309e-06, 6.06720e-07, 1.80861e-03, 3.31258e-05],
  253. [ 1.21863e-05, 2.68421e-01, 7.29644e-01, 3.31258e-05]])
  254. >>> m.sum()
  255. 1.0
  256. Compute the softmax transformation along the first axis (i.e., the
  257. columns).
  258. >>> m = softmax(x, axis=0)
  259. >>> m
  260. array([[ 2.11942e-01, 1.01300e-05, 2.75394e-06, 3.33333e-01],
  261. [ 2.11942e-01, 2.26030e-06, 2.47262e-03, 3.33333e-01],
  262. [ 5.76117e-01, 9.99988e-01, 9.97525e-01, 3.33333e-01]])
  263. >>> m.sum(axis=0)
  264. array([ 1., 1., 1., 1.])
  265. Compute the softmax transformation along the second axis (i.e., the rows).
  266. >>> m = softmax(x, axis=1)
  267. >>> m
  268. array([[ 1.05877e-01, 6.42177e-02, 4.75736e-02, 7.82332e-01],
  269. [ 2.42746e-03, 3.28521e-04, 9.79307e-01, 1.79366e-02],
  270. [ 1.22094e-05, 2.68929e-01, 7.31025e-01, 3.31885e-05]])
  271. >>> m.sum(axis=1)
  272. array([ 1., 1., 1.])
  273. """
  274. xp = array_namespace(x)
  275. x = xp.asarray(x)
  276. x_max = xp.max(x, axis=axis, keepdims=True)
  277. exp_x_shifted = xp.exp(x - x_max)
  278. return exp_x_shifted / xp.sum(exp_x_shifted, axis=axis, keepdims=True)
  279. @xp_capabilities()
  280. def log_softmax(x, axis=None):
  281. r"""Compute the logarithm of the softmax function.
  282. In principle::
  283. log_softmax(x) = log(softmax(x))
  284. but using a more accurate implementation.
  285. Parameters
  286. ----------
  287. x : array_like
  288. Input array.
  289. axis : int or tuple of ints, optional
  290. Axis to compute values along. Default is None and softmax will be
  291. computed over the entire array `x`.
  292. Returns
  293. -------
  294. s : ndarray or scalar
  295. An array with the same shape as `x`. Exponential of the result will
  296. sum to 1 along the specified axis. If `x` is a scalar, a scalar is
  297. returned.
  298. Notes
  299. -----
  300. `log_softmax` is more accurate than ``np.log(softmax(x))`` with inputs that
  301. make `softmax` saturate (see examples below).
  302. .. versionadded:: 1.5.0
  303. Examples
  304. --------
  305. >>> import numpy as np
  306. >>> from scipy.special import log_softmax
  307. >>> from scipy.special import softmax
  308. >>> np.set_printoptions(precision=5)
  309. >>> x = np.array([1000.0, 1.0])
  310. >>> y = log_softmax(x)
  311. >>> y
  312. array([ 0., -999.])
  313. >>> with np.errstate(divide='ignore'):
  314. ... y = np.log(softmax(x))
  315. ...
  316. >>> y
  317. array([ 0., -inf])
  318. """
  319. xp = array_namespace(x)
  320. x = xp.asarray(x)
  321. x_max = xp.max(x, axis=axis, keepdims=True)
  322. if x_max.ndim > 0:
  323. x_max = xpx.at(x_max, ~xp.isfinite(x_max)).set(0)
  324. elif not xp.isfinite(x_max):
  325. x_max = 0
  326. tmp = x - x_max
  327. exp_tmp = xp.exp(tmp)
  328. # suppress warnings about log of zero
  329. with np.errstate(divide='ignore'):
  330. s = xp.sum(exp_tmp, axis=axis, keepdims=True)
  331. out = xp.log(s)
  332. return tmp - out