_variation.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import warnings
  2. import numpy as np
  3. from scipy._lib._array_api import (
  4. array_namespace,
  5. xp_capabilities,
  6. xp_device,
  7. _length_nonmasked,
  8. )
  9. import scipy._lib.array_api_extra as xpx
  10. from ._axis_nan_policy import _axis_nan_policy_factory
  11. @xp_capabilities()
  12. @_axis_nan_policy_factory(
  13. lambda x: x, n_outputs=1, result_to_tuple=lambda x, _: (x,)
  14. )
  15. def variation(a, axis=0, nan_policy='propagate', ddof=0, *, keepdims=False):
  16. """
  17. Compute the coefficient of variation.
  18. The coefficient of variation is the standard deviation divided by the
  19. mean. This function is equivalent to::
  20. np.std(x, axis=axis, ddof=ddof) / np.mean(x)
  21. The default for ``ddof`` is 0, but many definitions of the coefficient
  22. of variation use the square root of the unbiased sample variance
  23. for the sample standard deviation, which corresponds to ``ddof=1``.
  24. The function does not take the absolute value of the mean of the data,
  25. so the return value is negative if the mean is negative.
  26. Parameters
  27. ----------
  28. a : array_like
  29. Input array.
  30. axis : int or None, optional
  31. Axis along which to calculate the coefficient of variation.
  32. Default is 0. If None, compute over the whole array `a`.
  33. nan_policy : {'propagate', 'raise', 'omit'}, optional
  34. Defines how to handle when input contains ``nan``.
  35. The following options are available:
  36. * 'propagate': return ``nan``
  37. * 'raise': raise an exception
  38. * 'omit': perform the calculation with ``nan`` values omitted
  39. The default is 'propagate'.
  40. ddof : int, optional
  41. Gives the "Delta Degrees Of Freedom" used when computing the
  42. standard deviation. The divisor used in the calculation of the
  43. standard deviation is ``N - ddof``, where ``N`` is the number of
  44. elements. `ddof` must be less than ``N``; if it isn't, the result
  45. will be ``nan`` or ``inf``, depending on ``N`` and the values in
  46. the array. By default `ddof` is zero for backwards compatibility,
  47. but it is recommended to use ``ddof=1`` to ensure that the sample
  48. standard deviation is computed as the square root of the unbiased
  49. sample variance.
  50. Returns
  51. -------
  52. variation : ndarray
  53. The calculated variation along the requested axis.
  54. Notes
  55. -----
  56. There are several edge cases that are handled without generating a
  57. warning:
  58. * If both the mean and the standard deviation are zero, ``nan``
  59. is returned.
  60. * If the mean is zero and the standard deviation is nonzero, ``inf``
  61. is returned.
  62. * If the input has length zero (either because the array has zero
  63. length, or all the input values are ``nan`` and ``nan_policy`` is
  64. ``'omit'``), ``nan`` is returned.
  65. * If the input contains ``inf``, ``nan`` is returned.
  66. References
  67. ----------
  68. .. [1] Zwillinger, D. and Kokoska, S. (2000). CRC Standard
  69. Probability and Statistics Tables and Formulae. Chapman & Hall: New
  70. York. 2000.
  71. Examples
  72. --------
  73. >>> import numpy as np
  74. >>> from scipy.stats import variation
  75. >>> variation([1, 2, 3, 4, 5], ddof=1)
  76. 0.5270462766947299
  77. Compute the variation along a given dimension of an array that contains
  78. a few ``nan`` values:
  79. >>> x = np.array([[ 10.0, np.nan, 11.0, 19.0, 23.0, 29.0, 98.0],
  80. ... [ 29.0, 30.0, 32.0, 33.0, 35.0, 56.0, 57.0],
  81. ... [np.nan, np.nan, 12.0, 13.0, 16.0, 16.0, 17.0]])
  82. >>> variation(x, axis=1, ddof=1, nan_policy='omit')
  83. array([1.05109361, 0.31428986, 0.146483 ])
  84. """
  85. xp = array_namespace(a)
  86. a = xp.asarray(a)
  87. # `nan_policy` and `keepdims` are handled by `_axis_nan_policy`
  88. if axis is None:
  89. a = xp.reshape(a, (-1,))
  90. axis = 0
  91. n = xp.asarray(_length_nonmasked(a, axis=axis), dtype=a.dtype, device=xp_device(a))
  92. with (np.errstate(divide='ignore', invalid='ignore'), warnings.catch_warnings()):
  93. warnings.simplefilter("ignore")
  94. mean_a = xp.mean(a, axis=axis)
  95. std_a = xp.std(a, axis=axis)
  96. correction = (n / (n - ddof))**0.5 # we may need uncorrected std below
  97. result = std_a * correction / mean_a
  98. def special_case(std_a, mean_a):
  99. # xref data-apis/array-api-extra#196
  100. mxp = array_namespace(std_a, mean_a)
  101. # `_xp_inf` is a workaround for torch.copysign not accepting a scalar yet,
  102. # xref data-apis/array-api-compat#271
  103. _xp_inf = mxp.asarray(mxp.inf, dtype=mean_a.dtype, device=xp_device(mean_a))
  104. return mxp.where(std_a > 0, mxp.copysign(_xp_inf, mean_a), mxp.nan)
  105. result = xpx.apply_where((ddof == n), (std_a, mean_a),
  106. special_case, fill_value=result)
  107. return result[()] if result.ndim == 0 else result