# _methods.py (9.2 KB)
"""
Array methods which are called by both the C-code for the method
and the Python code for the NumPy-namespace function.
"""
  5. import os
  6. import pickle
  7. import warnings
  8. from contextlib import nullcontext
  9. import numpy as np
  10. from numpy._core import multiarray as mu
  11. from numpy._core import umath as um
  12. from numpy._core.multiarray import asanyarray
  13. from numpy._core import numerictypes as nt
  14. from numpy._core import _exceptions
  15. from numpy._globals import _NoValue
  16. # save those O(100) nanoseconds!
  17. bool_dt = mu.dtype("bool")
  18. umr_maximum = um.maximum.reduce
  19. umr_minimum = um.minimum.reduce
  20. umr_sum = um.add.reduce
  21. umr_prod = um.multiply.reduce
  22. umr_bitwise_count = um.bitwise_count
  23. umr_any = um.logical_or.reduce
  24. umr_all = um.logical_and.reduce
  25. # Complex types to -> (2,)float view for fast-path computation in _var()
  26. _complex_to_float = {
  27. nt.dtype(nt.csingle) : nt.dtype(nt.single),
  28. nt.dtype(nt.cdouble) : nt.dtype(nt.double),
  29. }
  30. # Special case for windows: ensure double takes precedence
  31. if nt.dtype(nt.longdouble) != nt.dtype(nt.double):
  32. _complex_to_float.update({
  33. nt.dtype(nt.clongdouble) : nt.dtype(nt.longdouble),
  34. })
# Pass reduction arguments positionally: skipping keyword-argument parsing
# saves about 15%-20% for very small reductions.
  37. def _amax(a, axis=None, out=None, keepdims=False,
  38. initial=_NoValue, where=True):
  39. return umr_maximum(a, axis, None, out, keepdims, initial, where)
  40. def _amin(a, axis=None, out=None, keepdims=False,
  41. initial=_NoValue, where=True):
  42. return umr_minimum(a, axis, None, out, keepdims, initial, where)
  43. def _sum(a, axis=None, dtype=None, out=None, keepdims=False,
  44. initial=_NoValue, where=True):
  45. return umr_sum(a, axis, dtype, out, keepdims, initial, where)
  46. def _prod(a, axis=None, dtype=None, out=None, keepdims=False,
  47. initial=_NoValue, where=True):
  48. return umr_prod(a, axis, dtype, out, keepdims, initial, where)
  49. def _any(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
  50. # By default, return a boolean for any and all
  51. if dtype is None:
  52. dtype = bool_dt
  53. # Parsing keyword arguments is currently fairly slow, so avoid it for now
  54. if where is True:
  55. return umr_any(a, axis, dtype, out, keepdims)
  56. return umr_any(a, axis, dtype, out, keepdims, where=where)
  57. def _all(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
  58. # By default, return a boolean for any and all
  59. if dtype is None:
  60. dtype = bool_dt
  61. # Parsing keyword arguments is currently fairly slow, so avoid it for now
  62. if where is True:
  63. return umr_all(a, axis, dtype, out, keepdims)
  64. return umr_all(a, axis, dtype, out, keepdims, where=where)
  65. def _count_reduce_items(arr, axis, keepdims=False, where=True):
  66. # fast-path for the default case
  67. if where is True:
  68. # no boolean mask given, calculate items according to axis
  69. if axis is None:
  70. axis = tuple(range(arr.ndim))
  71. elif not isinstance(axis, tuple):
  72. axis = (axis,)
  73. items = 1
  74. for ax in axis:
  75. items *= arr.shape[mu.normalize_axis_index(ax, arr.ndim)]
  76. items = nt.intp(items)
  77. else:
  78. # TODO: Optimize case when `where` is broadcast along a non-reduction
  79. # axis and full sum is more excessive than needed.
  80. # guarded to protect circular imports
  81. from numpy.lib._stride_tricks_impl import broadcast_to
  82. # count True values in (potentially broadcasted) boolean mask
  83. items = umr_sum(broadcast_to(where, arr.shape), axis, nt.intp, None,
  84. keepdims)
  85. return items
  86. def _clip(a, min=None, max=None, out=None, **kwargs):
  87. if a.dtype.kind in "iu":
  88. # If min/max is a Python integer, deal with out-of-bound values here.
  89. # (This enforces NEP 50 rules as no value based promotion is done.)
  90. if type(min) is int and min <= np.iinfo(a.dtype).min:
  91. min = None
  92. if type(max) is int and max >= np.iinfo(a.dtype).max:
  93. max = None
  94. if min is None and max is None:
  95. # return identity
  96. return um.positive(a, out=out, **kwargs)
  97. elif min is None:
  98. return um.minimum(a, max, out=out, **kwargs)
  99. elif max is None:
  100. return um.maximum(a, min, out=out, **kwargs)
  101. else:
  102. return um.clip(a, min, max, out=out, **kwargs)
  103. def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
  104. arr = asanyarray(a)
  105. is_float16_result = False
  106. rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
  107. if rcount == 0 if where is True else umr_any(rcount == 0, axis=None):
  108. warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)
  109. # Cast bool, unsigned int, and int to float64 by default
  110. if dtype is None:
  111. if issubclass(arr.dtype.type, (nt.integer, nt.bool)):
  112. dtype = mu.dtype('f8')
  113. elif issubclass(arr.dtype.type, nt.float16):
  114. dtype = mu.dtype('f4')
  115. is_float16_result = True
  116. ret = umr_sum(arr, axis, dtype, out, keepdims, where=where)
  117. if isinstance(ret, mu.ndarray):
  118. ret = um.true_divide(
  119. ret, rcount, out=ret, casting='unsafe', subok=False)
  120. if is_float16_result and out is None:
  121. ret = arr.dtype.type(ret)
  122. elif hasattr(ret, 'dtype'):
  123. if is_float16_result:
  124. ret = arr.dtype.type(ret / rcount)
  125. else:
  126. ret = ret.dtype.type(ret / rcount)
  127. else:
  128. ret = ret / rcount
  129. return ret
  130. def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
  131. where=True, mean=None):
  132. arr = asanyarray(a)
  133. rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
  134. # Make this warning show up on top.
  135. if ddof >= rcount if where is True else umr_any(ddof >= rcount, axis=None):
  136. warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning,
  137. stacklevel=2)
  138. # Cast bool, unsigned int, and int to float64 by default
  139. if dtype is None and issubclass(arr.dtype.type, (nt.integer, nt.bool)):
  140. dtype = mu.dtype('f8')
  141. if mean is not None:
  142. arrmean = mean
  143. else:
  144. # Compute the mean.
  145. # Note that if dtype is not of inexact type then arraymean will
  146. # not be either.
  147. arrmean = umr_sum(arr, axis, dtype, keepdims=True, where=where)
  148. # The shape of rcount has to match arrmean to not change the shape of
  149. # out in broadcasting. Otherwise, it cannot be stored back to arrmean.
  150. if rcount.ndim == 0:
  151. # fast-path for default case when where is True
  152. div = rcount
  153. else:
  154. # matching rcount to arrmean when where is specified as array
  155. div = rcount.reshape(arrmean.shape)
  156. if isinstance(arrmean, mu.ndarray):
  157. arrmean = um.true_divide(arrmean, div, out=arrmean,
  158. casting='unsafe', subok=False)
  159. elif hasattr(arrmean, "dtype"):
  160. arrmean = arrmean.dtype.type(arrmean / rcount)
  161. else:
  162. arrmean = arrmean / rcount
  163. # Compute sum of squared deviations from mean
  164. # Note that x may not be inexact and that we need it to be an array,
  165. # not a scalar.
  166. x = asanyarray(arr - arrmean)
  167. if issubclass(arr.dtype.type, (nt.floating, nt.integer)):
  168. x = um.multiply(x, x, out=x)
  169. # Fast-paths for built-in complex types
  170. elif x.dtype in _complex_to_float:
  171. xv = x.view(dtype=(_complex_to_float[x.dtype], (2,)))
  172. um.multiply(xv, xv, out=xv)
  173. x = um.add(xv[..., 0], xv[..., 1], out=x.real).real
  174. # Most general case; includes handling object arrays containing imaginary
  175. # numbers and complex types with non-native byteorder
  176. else:
  177. x = um.multiply(x, um.conjugate(x), out=x).real
  178. ret = umr_sum(x, axis, dtype, out, keepdims=keepdims, where=where)
  179. # Compute degrees of freedom and make sure it is not negative.
  180. rcount = um.maximum(rcount - ddof, 0)
  181. # divide by degrees of freedom
  182. if isinstance(ret, mu.ndarray):
  183. ret = um.true_divide(
  184. ret, rcount, out=ret, casting='unsafe', subok=False)
  185. elif hasattr(ret, 'dtype'):
  186. ret = ret.dtype.type(ret / rcount)
  187. else:
  188. ret = ret / rcount
  189. return ret
  190. def _std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
  191. where=True, mean=None):
  192. ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  193. keepdims=keepdims, where=where, mean=mean)
  194. if isinstance(ret, mu.ndarray):
  195. ret = um.sqrt(ret, out=ret)
  196. elif hasattr(ret, 'dtype'):
  197. ret = ret.dtype.type(um.sqrt(ret))
  198. else:
  199. ret = um.sqrt(ret)
  200. return ret
  201. def _ptp(a, axis=None, out=None, keepdims=False):
  202. return um.subtract(
  203. umr_maximum(a, axis, None, out, keepdims),
  204. umr_minimum(a, axis, None, None, keepdims),
  205. out
  206. )
  207. def _dump(self, file, protocol=2):
  208. if hasattr(file, 'write'):
  209. ctx = nullcontext(file)
  210. else:
  211. ctx = open(os.fspath(file), "wb")
  212. with ctx as f:
  213. pickle.dump(self, f, protocol=protocol)
  214. def _dumps(self, protocol=2):
  215. return pickle.dumps(self, protocol=protocol)
  216. def _bitwise_count(a, out=None, *, where=True, casting='same_kind',
  217. order='K', dtype=None, subok=True):
  218. return umr_bitwise_count(a, out, where=where, casting=casting,
  219. order=order, dtype=dtype, subok=subok)