| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420 |
- import numpy as np
- from scipy._lib._array_api import (
- array_namespace,
- xp_capabilities,
- xp_device,
- xp_size,
- xp_promote,
- xp_float_to_complex,
- )
- from scipy._lib import array_api_extra as xpx
- __all__ = ["logsumexp", "softmax", "log_softmax"]
- @xp_capabilities()
- def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
- """Compute the log of the sum of exponentials of input elements.
- Parameters
- ----------
- a : array_like
- Input array.
- axis : None or int or tuple of ints, optional
- Axis or axes over which the sum is taken. By default `axis` is None,
- and all elements are summed.
- .. versionadded:: 0.11.0
- b : array-like, optional
- Scaling factor for exp(`a`) must be of the same shape as `a` or
- broadcastable to `a`. These values may be negative in order to
- implement subtraction.
- .. versionadded:: 0.12.0
- keepdims : bool, optional
- If this is set to True, the axes which are reduced are left in the
- result as dimensions with size one. With this option, the result
- will broadcast correctly against the original array.
- .. versionadded:: 0.15.0
- return_sign : bool, optional
- If this is set to True, the result will be a pair containing sign
- information; if False, results that are negative will be returned
- as NaN. Default is False (no sign information).
- .. versionadded:: 0.16.0
- Returns
- -------
- res : ndarray
- The result, ``np.log(np.sum(np.exp(a)))`` calculated in a numerically
- more stable way. If `b` is given then ``np.log(np.sum(b*np.exp(a)))``
- is returned. If ``return_sign`` is True, ``res`` contains the log of
- the absolute value of the argument.
- sgn : ndarray
- If ``return_sign`` is True, this will be an array of floating-point
- numbers matching res containing +1, 0, -1 (for real-valued inputs)
- or a complex phase (for complex inputs). This gives the sign of the
- argument of the logarithm in ``res``.
- If ``return_sign`` is False, only one result is returned.
- See Also
- --------
- :data:`numpy.logaddexp`
- :data:`numpy.logaddexp2`
- Notes
- -----
- NumPy has a logaddexp function which is very similar to `logsumexp`, but
- only handles two arguments. `logaddexp.reduce` is similar to this
- function, but may be less stable.
- The logarithm is a multivalued function: for each :math:`x` there is an
- infinite number of :math:`z` such that :math:`exp(z) = x`. The convention
- is to return the :math:`z` whose imaginary part lies in :math:`(-pi, pi]`.
- Examples
- --------
- >>> import numpy as np
- >>> from scipy.special import logsumexp
- >>> a = np.arange(10)
- >>> logsumexp(a)
- 9.4586297444267107
- >>> np.log(np.sum(np.exp(a)))
- 9.4586297444267107
- With weights
- >>> a = np.arange(10)
- >>> b = np.arange(10, 0, -1)
- >>> logsumexp(a, b=b)
- 9.9170178533034665
- >>> np.log(np.sum(b*np.exp(a)))
- 9.9170178533034647
- Returning a sign flag
- >>> logsumexp([1,2],b=[1,-1],return_sign=True)
- (1.5413248546129181, -1.0)
- Notice that `logsumexp` does not directly support masked arrays. To use it
- on a masked array, convert the mask into zero weights:
- >>> a = np.ma.array([np.log(2), 2, np.log(3)],
- ... mask=[False, True, False])
- >>> b = (~a.mask).astype(int)
- >>> logsumexp(a.data, b=b), np.log(5)
- 1.6094379124341005, 1.6094379124341005
- """
- xp = array_namespace(a, b)
- a, b = xp_promote(a, b, broadcast=True, force_floating=True, xp=xp)
- a = xpx.atleast_nd(a, ndim=1, xp=xp)
- b = xpx.atleast_nd(b, ndim=1, xp=xp) if b is not None else b
- axis = tuple(range(a.ndim)) if axis is None else axis
- if xp_size(a) != 0:
- with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
- # Where result is infinite, we use the direct logsumexp calculation to
- # delegate edge case handling to the behavior of `xp.log` and `xp.exp`,
- # which should follow the C99 standard for complex values.
- b_exp_a = xp.exp(a) if b is None else b * xp.exp(a)
- sum_ = xp.sum(b_exp_a, axis=axis, keepdims=True)
- sgn_inf = xp.sign(sum_) if return_sign else None
- sum_ = xp.abs(sum_) if return_sign else sum_
- out_inf = xp.log(sum_)
- with np.errstate(divide='ignore', invalid='ignore'): # log of zero is OK
- out, sgn = _logsumexp(a, b, axis=axis, return_sign=return_sign, xp=xp)
- # Replace infinite results. This probably could be done with an
- # `apply_where`-like strategy to avoid redundant calculation, but currently
- # `apply_where` itself is only for elementwise functions.
- out_finite = xp.isfinite(out)
- out = xp.where(out_finite, out, out_inf)
- sgn = xp.where(out_finite, sgn, sgn_inf) if return_sign else sgn
- else:
- shape = np.asarray(a.shape) # NumPy is convenient for shape manipulation
- shape[axis] = 1
- out = xp.full(tuple(shape), -xp.inf, dtype=a.dtype, device=xp_device(a))
- sgn = xp.sign(out)
- if xp.isdtype(out.dtype, 'complex floating'):
- if return_sign:
- real = xp.real(sgn)
- imag = xp_float_to_complex(_wrap_radians(xp.imag(sgn), xp=xp), xp=xp)
- sgn = real + imag*1j
- else:
- real = xp.real(out)
- imag = xp_float_to_complex(_wrap_radians(xp.imag(out), xp=xp), xp=xp)
- out = real + imag*1j
- # Deal with shape details - reducing dimensions and convert 0-D to scalar for NumPy
- out = xp.squeeze(out, axis=axis) if not keepdims else out
- sgn = xp.squeeze(sgn, axis=axis) if (sgn is not None and not keepdims) else sgn
- out = out[()] if out.ndim == 0 else out
- sgn = sgn[()] if (sgn is not None and sgn.ndim == 0) else sgn
- return (out, sgn) if return_sign else out
- def _wrap_radians(x, *, xp):
- # Wrap radians to (-pi, pi] interval
- wrapped = -((-x + xp.pi) % (2 * xp.pi) - xp.pi)
- # preserve relative precision
- no_wrap = xp.abs(x) < xp.pi
- return xp.where(no_wrap, x, wrapped)
- def _elements_and_indices_with_max_real(a, *, axis=-1, xp):
- # This is an array-API compatible `max` function that works something
- # like `np.max` for complex input. The important part is that it finds
- # the element with maximum real part. When there are multiple complex values
- # with this real part, it doesn't matter which we choose.
- # We could use `argmax` on real component, but array API doesn't yet have
- # `take_along_axis`, and even if it did, we would have problems with axis tuples.
- # Feel free to rewrite! It's ugly, but it's not the purpose of the PR, and
- # it gets the job done.
- if xp.isdtype(a.dtype, "complex floating"):
- # select all elements with max real part.
- real_a = xp.real(a)
- max_ = xp.max(real_a, axis=axis, keepdims=True)
- mask = real_a == max_
- # Of those, choose one arbitrarily. This is a reasonably
- # simple, array-API compatible way of doing so that doesn't
- # have a problem with `axis` being a tuple or None.
- i = xp.reshape(xp.arange(xp_size(a), device=xp_device(a)), a.shape)
- i = xpx.at(i, ~mask).set(-1)
- max_i = xp.max(i, axis=axis, keepdims=True)
- mask = i == max_i
- a = xp.where(mask, a, 0.)
- max_ = xp.sum(a, axis=axis, dtype=a.dtype, keepdims=True)
- else:
- max_ = xp.max(a, axis=axis, keepdims=True)
- mask = a == max_
- return max_, mask
- def _logsumexp(a, b, *, axis, return_sign, xp):
- # This has been around for about a decade, so let's consider it a feature:
- # Even if element of `a` is infinite or NaN, it adds nothing to the sum if
- # the corresponding weight is zero.
- if b is not None:
- a = xpx.at(a, b == 0).set(-xp.inf, copy=True)
- # Find element with maximum real part, since this is what affects the magnitude
- # of the exponential. Possible enhancement: include log of `b` magnitude in `a`.
- a_max, i_max = _elements_and_indices_with_max_real(a, axis=axis, xp=xp)
- # for precision, these terms are separated out of the main sum.
- a = xpx.at(a, i_max).set(-xp.inf, copy=True if b is None else None)
- i_max_dt = xp.astype(i_max, a.dtype)
- # This is an inefficient way of getting `m` because it is the sum of a sparse
- # array; however, this is the simplest way I can think of to get the right shape.
- b_i_max = i_max_dt if b is None else b * i_max_dt
- m = xp.sum(b_i_max, axis=axis, keepdims=True, dtype=a.dtype)
- # Shift, exponentiate, scale, and sum
- exp = b * xp.exp(a - a_max) if b is not None else xp.exp(a - a_max)
- s = xp.sum(exp, axis=axis, keepdims=True, dtype=exp.dtype)
- s = xp.where(s == 0, s, s/m)
- # Separate sign/magnitude information
- # Originally, this was only performed if `return_sign=True`.
- # However, this is also needed if any elements of `m < 0` or `s < -1`.
- # An improvement would be to perform the calculations only on these entries.
- sgn = xp.sign(s + 1) * xp.sign(m)
- if xp.isdtype(s.dtype, "real floating"):
- # The log functions need positive arguments
- s = xp.where(s < -1, -s - 2, s)
- m = xp.abs(m)
- else:
- # `a_max` can have a sign component for complex input
- sgn = sgn * xp.exp(xp.imag(a_max) * 1.0j)
- # Take log and undo shift
- out = xp.log1p(s) + xp.log(m) + a_max
- if return_sign:
- out = xp.real(out)
- elif xp.isdtype(out.dtype, 'real floating'):
- out = xpx.at(out)[sgn < 0].set(xp.nan)
- return out, sgn
- @xp_capabilities()
- def softmax(x, axis=None):
- r"""Compute the softmax function.
- The softmax function transforms each element of a collection by
- computing the exponential of each element divided by the sum of the
- exponentials of all the elements. That is, if `x` is a one-dimensional
- numpy array::
- softmax(x) = np.exp(x)/sum(np.exp(x))
- Parameters
- ----------
- x : array_like
- Input array.
- axis : int or tuple of ints, optional
- Axis to compute values along. Default is None and softmax will be
- computed over the entire array `x`.
- Returns
- -------
- s : ndarray
- An array the same shape as `x`. The result will sum to 1 along the
- specified axis.
- Notes
- -----
- The formula for the softmax function :math:`\sigma(x)` for a vector
- :math:`x = \{x_0, x_1, ..., x_{n-1}\}` is
- .. math:: \sigma(x)_j = \frac{e^{x_j}}{\sum_k e^{x_k}}
- The `softmax` function is the gradient of `logsumexp`.
- The implementation uses shifting to avoid overflow. See [1]_ for more
- details.
- .. versionadded:: 1.2.0
- References
- ----------
- .. [1] P. Blanchard, D.J. Higham, N.J. Higham, "Accurately computing the
- log-sum-exp and softmax functions", IMA Journal of Numerical Analysis,
- Vol.41(4), :doi:`10.1093/imanum/draa038`.
- Examples
- --------
- >>> import numpy as np
- >>> from scipy.special import softmax
- >>> np.set_printoptions(precision=5)
- >>> x = np.array([[1, 0.5, 0.2, 3],
- ... [1, -1, 7, 3],
- ... [2, 12, 13, 3]])
- ...
- Compute the softmax transformation over the entire array.
- >>> m = softmax(x)
- >>> m
- array([[ 4.48309e-06, 2.71913e-06, 2.01438e-06, 3.31258e-05],
- [ 4.48309e-06, 6.06720e-07, 1.80861e-03, 3.31258e-05],
- [ 1.21863e-05, 2.68421e-01, 7.29644e-01, 3.31258e-05]])
- >>> m.sum()
- 1.0
- Compute the softmax transformation along the first axis (i.e., the
- columns).
- >>> m = softmax(x, axis=0)
- >>> m
- array([[ 2.11942e-01, 1.01300e-05, 2.75394e-06, 3.33333e-01],
- [ 2.11942e-01, 2.26030e-06, 2.47262e-03, 3.33333e-01],
- [ 5.76117e-01, 9.99988e-01, 9.97525e-01, 3.33333e-01]])
- >>> m.sum(axis=0)
- array([ 1., 1., 1., 1.])
- Compute the softmax transformation along the second axis (i.e., the rows).
- >>> m = softmax(x, axis=1)
- >>> m
- array([[ 1.05877e-01, 6.42177e-02, 4.75736e-02, 7.82332e-01],
- [ 2.42746e-03, 3.28521e-04, 9.79307e-01, 1.79366e-02],
- [ 1.22094e-05, 2.68929e-01, 7.31025e-01, 3.31885e-05]])
- >>> m.sum(axis=1)
- array([ 1., 1., 1.])
- """
- xp = array_namespace(x)
- x = xp.asarray(x)
- x_max = xp.max(x, axis=axis, keepdims=True)
- exp_x_shifted = xp.exp(x - x_max)
- return exp_x_shifted / xp.sum(exp_x_shifted, axis=axis, keepdims=True)
- @xp_capabilities()
- def log_softmax(x, axis=None):
- r"""Compute the logarithm of the softmax function.
- In principle::
- log_softmax(x) = log(softmax(x))
- but using a more accurate implementation.
- Parameters
- ----------
- x : array_like
- Input array.
- axis : int or tuple of ints, optional
- Axis to compute values along. Default is None and softmax will be
- computed over the entire array `x`.
- Returns
- -------
- s : ndarray or scalar
- An array with the same shape as `x`. Exponential of the result will
- sum to 1 along the specified axis. If `x` is a scalar, a scalar is
- returned.
- Notes
- -----
- `log_softmax` is more accurate than ``np.log(softmax(x))`` with inputs that
- make `softmax` saturate (see examples below).
- .. versionadded:: 1.5.0
- Examples
- --------
- >>> import numpy as np
- >>> from scipy.special import log_softmax
- >>> from scipy.special import softmax
- >>> np.set_printoptions(precision=5)
- >>> x = np.array([1000.0, 1.0])
- >>> y = log_softmax(x)
- >>> y
- array([ 0., -999.])
- >>> with np.errstate(divide='ignore'):
- ... y = np.log(softmax(x))
- ...
- >>> y
- array([ 0., -inf])
- """
- xp = array_namespace(x)
- x = xp.asarray(x)
- x_max = xp.max(x, axis=axis, keepdims=True)
- if x_max.ndim > 0:
- x_max = xpx.at(x_max, ~xp.isfinite(x_max)).set(0)
- elif not xp.isfinite(x_max):
- x_max = 0
- tmp = x - x_max
- exp_tmp = xp.exp(tmp)
- # suppress warnings about log of zero
- with np.errstate(divide='ignore'):
- s = xp.sum(exp_tmp, axis=axis, keepdims=True)
- out = xp.log(s)
- return tmp - out
|