_shape_base_impl.pyi 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. from collections.abc import Callable, Sequence
  2. from typing import (
  3. TypeVar,
  4. Any,
  5. overload,
  6. SupportsIndex,
  7. Protocol,
  8. ParamSpec,
  9. Concatenate,
  10. type_check_only,
  11. )
  12. from typing_extensions import deprecated
  13. import numpy as np
  14. from numpy import _CastingKind, generic, integer, ufunc, unsignedinteger, signedinteger, floating, complexfloating, object_
  15. from numpy._typing import (
  16. ArrayLike,
  17. DTypeLike,
  18. NDArray,
  19. _ShapeLike,
  20. _ArrayLike,
  21. _ArrayLikeBool_co,
  22. _ArrayLikeUInt_co,
  23. _ArrayLikeInt_co,
  24. _ArrayLikeFloat_co,
  25. _ArrayLikeComplex_co,
  26. _ArrayLikeObject_co,
  27. )
  28. __all__ = [
  29. "column_stack",
  30. "row_stack",
  31. "dstack",
  32. "array_split",
  33. "split",
  34. "hsplit",
  35. "vsplit",
  36. "dsplit",
  37. "apply_over_axes",
  38. "expand_dims",
  39. "apply_along_axis",
  40. "kron",
  41. "tile",
  42. "take_along_axis",
  43. "put_along_axis",
  44. ]
  45. _P = ParamSpec("_P")
  46. _SCT = TypeVar("_SCT", bound=generic)
  47. # Signature of `__array_wrap__`
  48. @type_check_only
  49. class _ArrayWrap(Protocol):
  50. def __call__(
  51. self,
  52. array: NDArray[Any],
  53. context: None | tuple[ufunc, tuple[Any, ...], int] = ...,
  54. return_scalar: bool = ...,
  55. /,
  56. ) -> Any: ...
  57. @type_check_only
  58. class _SupportsArrayWrap(Protocol):
  59. @property
  60. def __array_wrap__(self) -> _ArrayWrap: ...
  61. ###
  62. def take_along_axis(
  63. arr: _SCT | NDArray[_SCT],
  64. indices: NDArray[integer[Any]],
  65. axis: None | int,
  66. ) -> NDArray[_SCT]: ...
  67. def put_along_axis(
  68. arr: NDArray[_SCT],
  69. indices: NDArray[integer[Any]],
  70. values: ArrayLike,
  71. axis: None | int,
  72. ) -> None: ...
  73. @overload
  74. def apply_along_axis(
  75. func1d: Callable[Concatenate[NDArray[Any], _P], _ArrayLike[_SCT]],
  76. axis: SupportsIndex,
  77. arr: ArrayLike,
  78. *args: _P.args,
  79. **kwargs: _P.kwargs,
  80. ) -> NDArray[_SCT]: ...
  81. @overload
  82. def apply_along_axis(
  83. func1d: Callable[Concatenate[NDArray[Any], _P], Any],
  84. axis: SupportsIndex,
  85. arr: ArrayLike,
  86. *args: _P.args,
  87. **kwargs: _P.kwargs,
  88. ) -> NDArray[Any]: ...
  89. def apply_over_axes(
  90. func: Callable[[NDArray[Any], int], NDArray[_SCT]],
  91. a: ArrayLike,
  92. axes: int | Sequence[int],
  93. ) -> NDArray[_SCT]: ...
  94. @overload
  95. def expand_dims(
  96. a: _ArrayLike[_SCT],
  97. axis: _ShapeLike,
  98. ) -> NDArray[_SCT]: ...
  99. @overload
  100. def expand_dims(
  101. a: ArrayLike,
  102. axis: _ShapeLike,
  103. ) -> NDArray[Any]: ...
  104. # Deprecated in NumPy 2.0, 2023-08-18
  105. @deprecated("`row_stack` alias is deprecated. Use `np.vstack` directly.")
  106. def row_stack(
  107. tup: Sequence[ArrayLike],
  108. *,
  109. dtype: DTypeLike | None = None,
  110. casting: _CastingKind = "same_kind",
  111. ) -> NDArray[Any]: ...
  112. #
  113. @overload
  114. def column_stack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
  115. @overload
  116. def column_stack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
  117. @overload
  118. def dstack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
  119. @overload
  120. def dstack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
  121. @overload
  122. def array_split(
  123. ary: _ArrayLike[_SCT],
  124. indices_or_sections: _ShapeLike,
  125. axis: SupportsIndex = ...,
  126. ) -> list[NDArray[_SCT]]: ...
  127. @overload
  128. def array_split(
  129. ary: ArrayLike,
  130. indices_or_sections: _ShapeLike,
  131. axis: SupportsIndex = ...,
  132. ) -> list[NDArray[Any]]: ...
  133. @overload
  134. def split(
  135. ary: _ArrayLike[_SCT],
  136. indices_or_sections: _ShapeLike,
  137. axis: SupportsIndex = ...,
  138. ) -> list[NDArray[_SCT]]: ...
  139. @overload
  140. def split(
  141. ary: ArrayLike,
  142. indices_or_sections: _ShapeLike,
  143. axis: SupportsIndex = ...,
  144. ) -> list[NDArray[Any]]: ...
  145. @overload
  146. def hsplit(
  147. ary: _ArrayLike[_SCT],
  148. indices_or_sections: _ShapeLike,
  149. ) -> list[NDArray[_SCT]]: ...
  150. @overload
  151. def hsplit(
  152. ary: ArrayLike,
  153. indices_or_sections: _ShapeLike,
  154. ) -> list[NDArray[Any]]: ...
  155. @overload
  156. def vsplit(
  157. ary: _ArrayLike[_SCT],
  158. indices_or_sections: _ShapeLike,
  159. ) -> list[NDArray[_SCT]]: ...
  160. @overload
  161. def vsplit(
  162. ary: ArrayLike,
  163. indices_or_sections: _ShapeLike,
  164. ) -> list[NDArray[Any]]: ...
  165. @overload
  166. def dsplit(
  167. ary: _ArrayLike[_SCT],
  168. indices_or_sections: _ShapeLike,
  169. ) -> list[NDArray[_SCT]]: ...
  170. @overload
  171. def dsplit(
  172. ary: ArrayLike,
  173. indices_or_sections: _ShapeLike,
  174. ) -> list[NDArray[Any]]: ...
  175. @overload
  176. def get_array_wrap(*args: _SupportsArrayWrap) -> _ArrayWrap: ...
  177. @overload
  178. def get_array_wrap(*args: object) -> None | _ArrayWrap: ...
  179. @overload
  180. def kron(a: _ArrayLikeBool_co, b: _ArrayLikeBool_co) -> NDArray[np.bool]: ... # type: ignore[misc]
  181. @overload
  182. def kron(a: _ArrayLikeUInt_co, b: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[Any]]: ... # type: ignore[misc]
  183. @overload
  184. def kron(a: _ArrayLikeInt_co, b: _ArrayLikeInt_co) -> NDArray[signedinteger[Any]]: ... # type: ignore[misc]
  185. @overload
  186. def kron(a: _ArrayLikeFloat_co, b: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ... # type: ignore[misc]
  187. @overload
  188. def kron(a: _ArrayLikeComplex_co, b: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
  189. @overload
  190. def kron(a: _ArrayLikeObject_co, b: Any) -> NDArray[object_]: ...
  191. @overload
  192. def kron(a: Any, b: _ArrayLikeObject_co) -> NDArray[object_]: ...
  193. @overload
  194. def tile(
  195. A: _ArrayLike[_SCT],
  196. reps: int | Sequence[int],
  197. ) -> NDArray[_SCT]: ...
  198. @overload
  199. def tile(
  200. A: ArrayLike,
  201. reps: int | Sequence[int],
  202. ) -> NDArray[Any]: ...