_array_like.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. from __future__ import annotations
  2. import sys
  3. from collections.abc import Collection, Callable, Sequence
  4. from typing import Any, Protocol, TypeAlias, TypeVar, runtime_checkable, TYPE_CHECKING
  5. import numpy as np
  6. from numpy import (
  7. ndarray,
  8. dtype,
  9. generic,
  10. unsignedinteger,
  11. integer,
  12. floating,
  13. complexfloating,
  14. number,
  15. timedelta64,
  16. datetime64,
  17. object_,
  18. void,
  19. str_,
  20. bytes_,
  21. )
  22. from ._nbit_base import _32Bit, _64Bit
  23. from ._nested_sequence import _NestedSequence
  24. from ._shape import _Shape
  25. if TYPE_CHECKING:
  26. StringDType = np.dtypes.StringDType
  27. else:
  28. # at runtime outside of type checking importing this from numpy.dtypes
  29. # would lead to a circular import
  30. from numpy._core.multiarray import StringDType
  31. _T = TypeVar("_T")
  32. _ScalarType = TypeVar("_ScalarType", bound=generic)
  33. _ScalarType_co = TypeVar("_ScalarType_co", bound=generic, covariant=True)
  34. _DType = TypeVar("_DType", bound=dtype[Any])
  35. _DType_co = TypeVar("_DType_co", covariant=True, bound=dtype[Any])
  36. NDArray: TypeAlias = ndarray[_Shape, dtype[_ScalarType_co]]
  37. # The `_SupportsArray` protocol only cares about the default dtype
  38. # (i.e. `dtype=None` or no `dtype` parameter at all) of the to-be returned
  39. # array.
  40. # Concrete implementations of the protocol are responsible for adding
  41. # any and all remaining overloads
  42. @runtime_checkable
  43. class _SupportsArray(Protocol[_DType_co]):
  44. def __array__(self) -> ndarray[Any, _DType_co]: ...
  45. @runtime_checkable
  46. class _SupportsArrayFunc(Protocol):
  47. """A protocol class representing `~class.__array_function__`."""
  48. def __array_function__(
  49. self,
  50. func: Callable[..., Any],
  51. types: Collection[type[Any]],
  52. args: tuple[Any, ...],
  53. kwargs: dict[str, Any],
  54. ) -> object: ...
  55. # TODO: Wait until mypy supports recursive objects in combination with typevars
  56. _FiniteNestedSequence: TypeAlias = (
  57. _T
  58. | Sequence[_T]
  59. | Sequence[Sequence[_T]]
  60. | Sequence[Sequence[Sequence[_T]]]
  61. | Sequence[Sequence[Sequence[Sequence[_T]]]]
  62. )
  63. # A subset of `npt.ArrayLike` that can be parametrized w.r.t. `np.generic`
  64. _ArrayLike: TypeAlias = (
  65. _SupportsArray[dtype[_ScalarType]]
  66. | _NestedSequence[_SupportsArray[dtype[_ScalarType]]]
  67. )
  68. # A union representing array-like objects; consists of two typevars:
  69. # One representing types that can be parametrized w.r.t. `np.dtype`
  70. # and another one for the rest
  71. _DualArrayLike: TypeAlias = (
  72. _SupportsArray[_DType]
  73. | _NestedSequence[_SupportsArray[_DType]]
  74. | _T
  75. | _NestedSequence[_T]
  76. )
  77. if sys.version_info >= (3, 12):
  78. from collections.abc import Buffer as _Buffer
  79. else:
  80. @runtime_checkable
  81. class _Buffer(Protocol):
  82. def __buffer__(self, flags: int, /) -> memoryview: ...
  83. ArrayLike: TypeAlias = _Buffer | _DualArrayLike[
  84. dtype[Any],
  85. bool | int | float | complex | str | bytes,
  86. ]
  87. # `ArrayLike<X>_co`: array-like objects that can be coerced into `X`
  88. # given the casting rules `same_kind`
  89. _ArrayLikeBool_co: TypeAlias = _DualArrayLike[
  90. dtype[np.bool],
  91. bool,
  92. ]
  93. _ArrayLikeUInt_co: TypeAlias = _DualArrayLike[
  94. dtype[np.bool] | dtype[unsignedinteger[Any]],
  95. bool,
  96. ]
  97. _ArrayLikeInt_co: TypeAlias = _DualArrayLike[
  98. dtype[np.bool] | dtype[integer[Any]],
  99. bool | int,
  100. ]
  101. _ArrayLikeFloat_co: TypeAlias = _DualArrayLike[
  102. dtype[np.bool] | dtype[integer[Any]] | dtype[floating[Any]],
  103. bool | int | float,
  104. ]
  105. _ArrayLikeComplex_co: TypeAlias = _DualArrayLike[
  106. (
  107. dtype[np.bool]
  108. | dtype[integer[Any]]
  109. | dtype[floating[Any]]
  110. | dtype[complexfloating[Any, Any]]
  111. ),
  112. bool | int | float | complex,
  113. ]
  114. _ArrayLikeNumber_co: TypeAlias = _DualArrayLike[
  115. dtype[np.bool] | dtype[number[Any]],
  116. bool | int | float | complex,
  117. ]
  118. _ArrayLikeTD64_co: TypeAlias = _DualArrayLike[
  119. dtype[np.bool] | dtype[integer[Any]] | dtype[timedelta64],
  120. bool | int,
  121. ]
  122. _ArrayLikeDT64_co: TypeAlias = (
  123. _SupportsArray[dtype[datetime64]]
  124. | _NestedSequence[_SupportsArray[dtype[datetime64]]]
  125. )
  126. _ArrayLikeObject_co: TypeAlias = (
  127. _SupportsArray[dtype[object_]]
  128. | _NestedSequence[_SupportsArray[dtype[object_]]]
  129. )
  130. _ArrayLikeVoid_co: TypeAlias = (
  131. _SupportsArray[dtype[void]]
  132. | _NestedSequence[_SupportsArray[dtype[void]]]
  133. )
  134. _ArrayLikeStr_co: TypeAlias = _DualArrayLike[
  135. dtype[str_],
  136. str,
  137. ]
  138. _ArrayLikeBytes_co: TypeAlias = _DualArrayLike[
  139. dtype[bytes_],
  140. bytes,
  141. ]
  142. _ArrayLikeString_co: TypeAlias = _DualArrayLike[
  143. StringDType,
  144. str
  145. ]
  146. _ArrayLikeAnyString_co: TypeAlias = (
  147. _ArrayLikeStr_co |
  148. _ArrayLikeBytes_co |
  149. _ArrayLikeString_co
  150. )
  151. __Float64_co: TypeAlias = np.floating[_64Bit] | np.float32 | np.float16 | np.integer | np.bool
  152. __Complex128_co: TypeAlias = np.number[_64Bit] | np.number[_32Bit] | np.float16 | np.integer | np.bool
  153. _ArrayLikeFloat64_co: TypeAlias = _DualArrayLike[dtype[__Float64_co], float | int]
  154. _ArrayLikeComplex128_co: TypeAlias = _DualArrayLike[dtype[__Complex128_co], complex | float | int]
  155. # NOTE: This includes `builtins.bool`, but not `numpy.bool`.
  156. _ArrayLikeInt: TypeAlias = _DualArrayLike[
  157. dtype[integer[Any]],
  158. int,
  159. ]
  160. # Extra ArrayLike type so that pyright can deal with NDArray[Any]
  161. # Used as the first overload, should only match NDArray[Any],
  162. # not any actual types.
  163. # https://github.com/numpy/numpy/pull/22193
  164. if sys.version_info >= (3, 11):
  165. from typing import Never as _UnknownType
  166. else:
  167. from typing import NoReturn as _UnknownType
  168. _ArrayLikeUnknown: TypeAlias = _DualArrayLike[
  169. dtype[_UnknownType],
  170. _UnknownType,
  171. ]