mixins.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. """
  2. Mixin classes for custom array types that don't inherit from ndarray.
  3. """
  4. from numpy._core import umath as um
# Public API of this module; the underscore-prefixed helpers below are
# implementation details used to build the mixin's special methods.
__all__ = ['NDArrayOperatorsMixin']
  6. def _disables_array_ufunc(obj):
  7. """True when __array_ufunc__ is set to None."""
  8. try:
  9. return obj.__array_ufunc__ is None
  10. except AttributeError:
  11. return False
  12. def _binary_method(ufunc, name):
  13. """Implement a forward binary method with a ufunc, e.g., __add__."""
  14. def func(self, other):
  15. if _disables_array_ufunc(other):
  16. return NotImplemented
  17. return ufunc(self, other)
  18. func.__name__ = '__{}__'.format(name)
  19. return func
  20. def _reflected_binary_method(ufunc, name):
  21. """Implement a reflected binary method with a ufunc, e.g., __radd__."""
  22. def func(self, other):
  23. if _disables_array_ufunc(other):
  24. return NotImplemented
  25. return ufunc(other, self)
  26. func.__name__ = '__r{}__'.format(name)
  27. return func
  28. def _inplace_binary_method(ufunc, name):
  29. """Implement an in-place binary method with a ufunc, e.g., __iadd__."""
  30. def func(self, other):
  31. return ufunc(self, other, out=(self,))
  32. func.__name__ = '__i{}__'.format(name)
  33. return func
  34. def _numeric_methods(ufunc, name):
  35. """Implement forward, reflected and inplace binary methods with a ufunc."""
  36. return (_binary_method(ufunc, name),
  37. _reflected_binary_method(ufunc, name),
  38. _inplace_binary_method(ufunc, name))
  39. def _unary_method(ufunc, name):
  40. """Implement a unary special method with a ufunc."""
  41. def func(self):
  42. return ufunc(self)
  43. func.__name__ = '__{}__'.format(name)
  44. return func
class NDArrayOperatorsMixin:
    """Mixin defining all operator special methods using __array_ufunc__.

    This class implements the special methods for almost all of Python's
    builtin operators defined in the `operator` module, including comparisons
    (``==``, ``>``, etc.) and arithmetic (``+``, ``*``, ``-``, etc.), by
    deferring to the ``__array_ufunc__`` method, which subclasses must
    implement.

    It is useful for writing classes that do not inherit from `numpy.ndarray`,
    but that should support arithmetic and numpy universal functions like
    arrays as described in `A Mechanism for Overriding Ufuncs
    <https://numpy.org/neps/nep-0013-ufunc-overrides.html>`_.

    As a trivial example, consider this implementation of an ``ArrayLike``
    class that simply wraps a NumPy array and ensures that the result of any
    arithmetic operation is also an ``ArrayLike`` object:

        >>> import numbers
        >>> class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin):
        ...     def __init__(self, value):
        ...         self.value = np.asarray(value)
        ...
        ...     # One might also consider adding the built-in list type to this
        ...     # list, to support operations like np.add(array_like, list)
        ...     _HANDLED_TYPES = (np.ndarray, numbers.Number)
        ...
        ...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        ...         out = kwargs.get('out', ())
        ...         for x in inputs + out:
        ...             # Only support operations with instances of
        ...             # _HANDLED_TYPES. Use ArrayLike instead of type(self)
        ...             # for isinstance to allow subclasses that don't
        ...             # override __array_ufunc__ to handle ArrayLike objects.
        ...             if not isinstance(
        ...                 x, self._HANDLED_TYPES + (ArrayLike,)
        ...             ):
        ...                 return NotImplemented
        ...
        ...         # Defer to the implementation of the ufunc
        ...         # on unwrapped values.
        ...         inputs = tuple(x.value if isinstance(x, ArrayLike) else x
        ...                     for x in inputs)
        ...         if out:
        ...             kwargs['out'] = tuple(
        ...                 x.value if isinstance(x, ArrayLike) else x
        ...                 for x in out)
        ...         result = getattr(ufunc, method)(*inputs, **kwargs)
        ...
        ...         if type(result) is tuple:
        ...             # multiple return values
        ...             return tuple(type(self)(x) for x in result)
        ...         elif method == 'at':
        ...             # no return value
        ...             return None
        ...         else:
        ...             # one return value
        ...             return type(self)(result)
        ...
        ...     def __repr__(self):
        ...         return '%s(%r)' % (type(self).__name__, self.value)

    In interactions between ``ArrayLike`` objects and numbers or numpy arrays,
    the result is always another ``ArrayLike``:

        >>> x = ArrayLike([1, 2, 3])
        >>> x - 1
        ArrayLike(array([0, 1, 2]))
        >>> 1 - x
        ArrayLike(array([ 0, -1, -2]))
        >>> np.arange(3) - x
        ArrayLike(array([-1, -1, -1]))
        >>> x - np.arange(3)
        ArrayLike(array([1, 1, 1]))

    Note that unlike ``numpy.ndarray``, ``ArrayLike`` does not allow operations
    with arbitrary, unrecognized types. This ensures that interactions with
    ArrayLike preserve a well-defined casting hierarchy.

    """
    # The mixin declares no instance attributes of its own.
    __slots__ = ()
    # Like np.ndarray, this mixin class implements "Option 1" from the ufunc
    # overrides NEP.

    # comparisons don't have reflected and in-place versions
    __lt__ = _binary_method(um.less, 'lt')
    __le__ = _binary_method(um.less_equal, 'le')
    __eq__ = _binary_method(um.equal, 'eq')
    __ne__ = _binary_method(um.not_equal, 'ne')
    __gt__ = _binary_method(um.greater, 'gt')
    __ge__ = _binary_method(um.greater_equal, 'ge')

    # numeric methods
    # Each _numeric_methods call yields the (forward, reflected, in-place)
    # trio for one ufunc, unpacked into the three matching special methods.
    __add__, __radd__, __iadd__ = _numeric_methods(um.add, 'add')
    __sub__, __rsub__, __isub__ = _numeric_methods(um.subtract, 'sub')
    __mul__, __rmul__, __imul__ = _numeric_methods(um.multiply, 'mul')
    __matmul__, __rmatmul__, __imatmul__ = _numeric_methods(
        um.matmul, 'matmul')
    # Python 3 does not use __div__, __rdiv__, or __idiv__
    __truediv__, __rtruediv__, __itruediv__ = _numeric_methods(
        um.true_divide, 'truediv')
    __floordiv__, __rfloordiv__, __ifloordiv__ = _numeric_methods(
        um.floor_divide, 'floordiv')
    __mod__, __rmod__, __imod__ = _numeric_methods(um.remainder, 'mod')
    # divmod is wired up by hand (forward and reflected only) because there
    # is no in-place variant to generate.
    __divmod__ = _binary_method(um.divmod, 'divmod')
    __rdivmod__ = _reflected_binary_method(um.divmod, 'divmod')
    # __idivmod__ does not exist
    # TODO: handle the optional third argument for __pow__?
    __pow__, __rpow__, __ipow__ = _numeric_methods(um.power, 'pow')
    __lshift__, __rlshift__, __ilshift__ = _numeric_methods(
        um.left_shift, 'lshift')
    __rshift__, __rrshift__, __irshift__ = _numeric_methods(
        um.right_shift, 'rshift')
    __and__, __rand__, __iand__ = _numeric_methods(um.bitwise_and, 'and')
    __xor__, __rxor__, __ixor__ = _numeric_methods(um.bitwise_xor, 'xor')
    __or__, __ror__, __ior__ = _numeric_methods(um.bitwise_or, 'or')

    # unary methods
    __neg__ = _unary_method(um.negative, 'neg')
    __pos__ = _unary_method(um.positive, 'pos')
    __abs__ = _unary_method(um.absolute, 'abs')
    __invert__ = _unary_method(um.invert, 'invert')
  155. __invert__ = _unary_method(um.invert, 'invert')