_shape_base_impl.pyi 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. from collections.abc import Callable, Sequence
  2. from typing import (
  3. Any,
  4. Concatenate,
  5. ParamSpec,
  6. Protocol,
  7. SupportsIndex,
  8. TypeVar,
  9. overload,
  10. type_check_only,
  11. )
  12. from typing_extensions import deprecated
  13. import numpy as np
  14. from numpy import (
  15. _CastingKind,
  16. complexfloating,
  17. floating,
  18. generic,
  19. integer,
  20. object_,
  21. signedinteger,
  22. ufunc,
  23. unsignedinteger,
  24. )
  25. from numpy._typing import (
  26. ArrayLike,
  27. DTypeLike,
  28. NDArray,
  29. _ArrayLike,
  30. _ArrayLikeBool_co,
  31. _ArrayLikeComplex_co,
  32. _ArrayLikeFloat_co,
  33. _ArrayLikeInt_co,
  34. _ArrayLikeObject_co,
  35. _ArrayLikeUInt_co,
  36. _ShapeLike,
  37. )
  # Public names exported by `from ... import *`, grouped by purpose; the
# original ordering is preserved.  `get_array_wrap` is not listed.
__all__ = [
    # stacking
    "column_stack", "row_stack", "dstack",
    # splitting
    "array_split", "split", "hsplit", "vsplit", "dsplit",
    # applying functions / inserting axes
    "apply_over_axes", "expand_dims", "apply_along_axis",
    # products and repetition
    "kron", "tile",
    # index-driven gather / scatter
    "take_along_axis", "put_along_axis",
]
  55. _P = ParamSpec("_P")
  56. _ScalarT = TypeVar("_ScalarT", bound=generic)
  # Signature of `__array_wrap__`: the array to wrap, an optional ufunc
# context and a `return_scalar` flag, all passed positionally.
@type_check_only
class _ArrayWrap(Protocol):
    # `context` is either None or a (ufunc, tuple-of-arguments, int) triple;
    # the int is presumably the output index (see NumPy's __array_wrap__ docs).
    def __call__(
        self,
        array: NDArray[Any],
        context: tuple[ufunc, tuple[Any, ...], int] | None = ...,
        return_scalar: bool = ...,
        /,
    ) -> Any: ...
  # Structural type for objects exposing an `__array_wrap__` member whose
# value matches the `_ArrayWrap` call signature.
@type_check_only
class _SupportsArrayWrap(Protocol):
    # Declared as a read-only property, so any attribute (or bound method)
    # of a compatible type presumably satisfies the protocol.
    @property
    def __array_wrap__(self) -> _ArrayWrap: ...
  71. ###
  # The result keeps the scalar type of `arr` (a scalar or an array of it);
# `indices` must be an integer array; `axis` defaults to -1 and may be None.
def take_along_axis(
    arr: _ScalarT | NDArray[_ScalarT], indices: NDArray[integer], axis: int | None = -1
) -> NDArray[_ScalarT]: ...
  # Declared to return None.  `arr` is annotated `NDArray[Any]` rather than
# `NDArray[_ScalarT]`: the scalar type appears nowhere else in this signature,
# so a TypeVar there would constrain nothing (pyright reports single-use
# TypeVars under reportInvalidTypeVarUse).  Every call accepted before is
# still accepted.  `axis` is required here (no default), unlike
# `take_along_axis`, whose `axis` defaults to -1.
def put_along_axis(
    arr: NDArray[Any],
    indices: NDArray[integer],
    values: ArrayLike,
    axis: int | None,
) -> None: ...
  83. @overload
  84. def apply_along_axis(
  85. func1d: Callable[Concatenate[NDArray[Any], _P], _ArrayLike[_ScalarT]],
  86. axis: SupportsIndex,
  87. arr: ArrayLike,
  88. *args: _P.args,
  89. **kwargs: _P.kwargs,
  90. ) -> NDArray[_ScalarT]: ...
  91. @overload
  92. def apply_along_axis(
  93. func1d: Callable[Concatenate[NDArray[Any], _P], Any],
  94. axis: SupportsIndex,
  95. arr: ArrayLike,
  96. *args: _P.args,
  97. **kwargs: _P.kwargs,
  98. ) -> NDArray[Any]: ...
  # `func` is called with an array and an int axis and must return an array;
# that array's scalar type carries through to the result.
def apply_over_axes(
    func: Callable[[NDArray[Any], int], NDArray[_ScalarT]],
    a: ArrayLike, axes: int | Sequence[int],
) -> NDArray[_ScalarT]: ...
  # `axis` accepts any `_ShapeLike`.  A known scalar type is preserved;
# otherwise the result is `NDArray[Any]`.
@overload
def expand_dims(a: _ArrayLike[_ScalarT], axis: _ShapeLike) -> NDArray[_ScalarT]: ...
@overload
def expand_dims(a: ArrayLike, axis: _ShapeLike) -> NDArray[Any]: ...
  # Deprecated in NumPy 2.0, 2023-08-18
# `@deprecated` lets type checkers flag uses; `dtype` and `casting` are
# keyword-only.
@deprecated("`row_stack` alias is deprecated. Use `np.vstack` directly.")
def row_stack(
    tup: Sequence[ArrayLike],
    *, dtype: DTypeLike | None = None, casting: _CastingKind = "same_kind",
) -> NDArray[Any]: ...
  # keep in sync with `numpy.ma.extras.column_stack`
# Inputs sharing one known scalar type keep it; mixed inputs fall back to Any.
@overload
def column_stack(
    tup: Sequence[_ArrayLike[_ScalarT]],
) -> NDArray[_ScalarT]: ...
@overload
def column_stack(
    tup: Sequence[ArrayLike],
) -> NDArray[Any]: ...
  # keep in sync with `numpy.ma.extras.dstack`
# Inputs sharing one known scalar type keep it; mixed inputs fall back to Any.
@overload
def dstack(
    tup: Sequence[_ArrayLike[_ScalarT]],
) -> NDArray[_ScalarT]: ...
@overload
def dstack(
    tup: Sequence[ArrayLike],
) -> NDArray[Any]: ...
  # Each piece keeps the input's scalar type when it is known; `axis`
# defaults to 0.
@overload
def array_split(
    ary: _ArrayLike[_ScalarT], indices_or_sections: _ShapeLike, axis: SupportsIndex = 0
) -> list[NDArray[_ScalarT]]: ...
@overload
def array_split(
    ary: ArrayLike, indices_or_sections: _ShapeLike, axis: SupportsIndex = 0
) -> list[NDArray[Any]]: ...
  # Same signature shape as `array_split`: scalar type preserved when known,
# `axis` defaulting to 0.
@overload
def split(
    ary: _ArrayLike[_ScalarT], indices_or_sections: _ShapeLike, axis: SupportsIndex = 0
) -> list[NDArray[_ScalarT]]: ...
@overload
def split(
    ary: ArrayLike, indices_or_sections: _ShapeLike, axis: SupportsIndex = 0
) -> list[NDArray[Any]]: ...
  # keep in sync with `numpy.ma.extras.hsplit`
# No `axis` parameter; scalar type preserved when known.
@overload
def hsplit(
    ary: _ArrayLike[_ScalarT], indices_or_sections: _ShapeLike
) -> list[NDArray[_ScalarT]]: ...
@overload
def hsplit(
    ary: ArrayLike, indices_or_sections: _ShapeLike
) -> list[NDArray[Any]]: ...
  # No `axis` parameter; scalar type preserved when known.
@overload
def vsplit(
    ary: _ArrayLike[_ScalarT], indices_or_sections: _ShapeLike
) -> list[NDArray[_ScalarT]]: ...
@overload
def vsplit(
    ary: ArrayLike, indices_or_sections: _ShapeLike
) -> list[NDArray[Any]]: ...
  # No `axis` parameter; scalar type preserved when known.
@overload
def dsplit(
    ary: _ArrayLike[_ScalarT], indices_or_sections: _ShapeLike
) -> list[NDArray[_ScalarT]]: ...
@overload
def dsplit(
    ary: ArrayLike, indices_or_sections: _ShapeLike
) -> list[NDArray[Any]]: ...
  187. @overload
  188. def get_array_wrap(*args: _SupportsArrayWrap) -> _ArrayWrap: ...
  189. @overload
  190. def get_array_wrap(*args: object) -> _ArrayWrap | None: ...
  # Overloads run from the narrowest numeric category to the widest:
# bool -> bool, unsigned -> unsignedinteger, signed -> signedinteger,
# float -> floating, complex -> complexfloating.  An object-typed operand on
# either side yields object_.  The `misc` ignores presumably silence
# overlapping-overload diagnostics; keep each on its `def` line.
@overload
def kron(a: _ArrayLikeBool_co, b: _ArrayLikeBool_co) -> NDArray[np.bool]: ...  # type: ignore[misc]
@overload
def kron(a: _ArrayLikeUInt_co, b: _ArrayLikeUInt_co) -> NDArray[unsignedinteger]: ...  # type: ignore[misc]
@overload
def kron(a: _ArrayLikeInt_co, b: _ArrayLikeInt_co) -> NDArray[signedinteger]: ...  # type: ignore[misc]
@overload
def kron(a: _ArrayLikeFloat_co, b: _ArrayLikeFloat_co) -> NDArray[floating]: ...  # type: ignore[misc]
@overload
def kron(a: _ArrayLikeComplex_co, b: _ArrayLikeComplex_co) -> NDArray[complexfloating]: ...
# object operands (left or right)
@overload
def kron(a: _ArrayLikeObject_co, b: Any) -> NDArray[object_]: ...
@overload
def kron(a: Any, b: _ArrayLikeObject_co) -> NDArray[object_]: ...
  # `reps` is a single int or a sequence of ints; a known scalar type of `A`
# is preserved, otherwise the result is `NDArray[Any]`.
@overload
def tile(A: _ArrayLike[_ScalarT], reps: int | Sequence[int]) -> NDArray[_ScalarT]: ...
@overload
def tile(A: ArrayLike, reps: int | Sequence[int]) -> NDArray[Any]: ...