_array_api_info.pyi 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. from typing import (
  2. ClassVar,
  3. Literal,
  4. TypeAlias,
  5. TypedDict,
  6. TypeVar,
  7. final,
  8. overload,
  9. type_check_only,
  10. )
  11. from typing_extensions import Never
  12. import numpy as np
  13. _Device: TypeAlias = Literal["cpu"]
  14. _DeviceLike: TypeAlias = None | _Device
  15. _Capabilities = TypedDict(
  16. "_Capabilities",
  17. {
  18. "boolean indexing": Literal[True],
  19. "data-dependent shapes": Literal[True],
  20. },
  21. )
  22. _DefaultDTypes = TypedDict(
  23. "_DefaultDTypes",
  24. {
  25. "real floating": np.dtype[np.float64],
  26. "complex floating": np.dtype[np.complex128],
  27. "integral": np.dtype[np.intp],
  28. "indexing": np.dtype[np.intp],
  29. },
  30. )
  31. _KindBool: TypeAlias = Literal["bool"]
  32. _KindInt: TypeAlias = Literal["signed integer"]
  33. _KindUInt: TypeAlias = Literal["unsigned integer"]
  34. _KindInteger: TypeAlias = Literal["integral"]
  35. _KindFloat: TypeAlias = Literal["real floating"]
  36. _KindComplex: TypeAlias = Literal["complex floating"]
  37. _KindNumber: TypeAlias = Literal["numeric"]
  38. _Kind: TypeAlias = (
  39. _KindBool
  40. | _KindInt
  41. | _KindUInt
  42. | _KindInteger
  43. | _KindFloat
  44. | _KindComplex
  45. | _KindNumber
  46. )
  47. _T1 = TypeVar("_T1")
  48. _T2 = TypeVar("_T2")
  49. _T3 = TypeVar("_T3")
  50. _Permute1: TypeAlias = _T1 | tuple[_T1]
  51. _Permute2: TypeAlias = tuple[_T1, _T2] | tuple[_T2, _T1]
  52. _Permute3: TypeAlias = (
  53. tuple[_T1, _T2, _T3] | tuple[_T1, _T3, _T2]
  54. | tuple[_T2, _T1, _T3] | tuple[_T2, _T3, _T1]
  55. | tuple[_T3, _T1, _T2] | tuple[_T3, _T2, _T1]
  56. )
  57. @type_check_only
  58. class _DTypesBool(TypedDict):
  59. bool: np.dtype[np.bool]
  60. @type_check_only
  61. class _DTypesInt(TypedDict):
  62. int8: np.dtype[np.int8]
  63. int16: np.dtype[np.int16]
  64. int32: np.dtype[np.int32]
  65. int64: np.dtype[np.int64]
  66. @type_check_only
  67. class _DTypesUInt(TypedDict):
  68. uint8: np.dtype[np.uint8]
  69. uint16: np.dtype[np.uint16]
  70. uint32: np.dtype[np.uint32]
  71. uint64: np.dtype[np.uint64]
  72. @type_check_only
  73. class _DTypesInteger(_DTypesInt, _DTypesUInt): ...
  74. @type_check_only
  75. class _DTypesFloat(TypedDict):
  76. float32: np.dtype[np.float32]
  77. float64: np.dtype[np.float64]
  78. @type_check_only
  79. class _DTypesComplex(TypedDict):
  80. complex64: np.dtype[np.complex64]
  81. complex128: np.dtype[np.complex128]
  82. @type_check_only
  83. class _DTypesNumber(_DTypesInteger, _DTypesFloat, _DTypesComplex): ...
  84. @type_check_only
  85. class _DTypes(_DTypesBool, _DTypesNumber): ...
  86. @type_check_only
  87. class _DTypesUnion(TypedDict, total=False):
  88. bool: np.dtype[np.bool]
  89. int8: np.dtype[np.int8]
  90. int16: np.dtype[np.int16]
  91. int32: np.dtype[np.int32]
  92. int64: np.dtype[np.int64]
  93. uint8: np.dtype[np.uint8]
  94. uint16: np.dtype[np.uint16]
  95. uint32: np.dtype[np.uint32]
  96. uint64: np.dtype[np.uint64]
  97. float32: np.dtype[np.float32]
  98. float64: np.dtype[np.float64]
  99. complex64: np.dtype[np.complex64]
  100. complex128: np.dtype[np.complex128]
  101. _EmptyDict: TypeAlias = dict[Never, Never]
  102. @final
  103. class __array_namespace_info__:
  104. __module__: ClassVar[Literal['numpy']]
  105. def capabilities(self) -> _Capabilities: ...
  106. def default_device(self) -> _Device: ...
  107. def default_dtypes(
  108. self,
  109. *,
  110. device: _DeviceLike = ...,
  111. ) -> _DefaultDTypes: ...
  112. def devices(self) -> list[_Device]: ...
  113. @overload
  114. def dtypes(
  115. self,
  116. *,
  117. device: _DeviceLike = ...,
  118. kind: None = ...,
  119. ) -> _DTypes: ...
  120. @overload
  121. def dtypes(
  122. self,
  123. *,
  124. device: _DeviceLike = ...,
  125. kind: _Permute1[_KindBool],
  126. ) -> _DTypesBool: ...
  127. @overload
  128. def dtypes(
  129. self,
  130. *,
  131. device: _DeviceLike = ...,
  132. kind: _Permute1[_KindInt],
  133. ) -> _DTypesInt: ...
  134. @overload
  135. def dtypes(
  136. self,
  137. *,
  138. device: _DeviceLike = ...,
  139. kind: _Permute1[_KindUInt],
  140. ) -> _DTypesUInt: ...
  141. @overload
  142. def dtypes(
  143. self,
  144. *,
  145. device: _DeviceLike = ...,
  146. kind: _Permute1[_KindFloat],
  147. ) -> _DTypesFloat: ...
  148. @overload
  149. def dtypes(
  150. self,
  151. *,
  152. device: _DeviceLike = ...,
  153. kind: _Permute1[_KindComplex],
  154. ) -> _DTypesComplex: ...
  155. @overload
  156. def dtypes(
  157. self,
  158. *,
  159. device: _DeviceLike = ...,
  160. kind: (
  161. _Permute1[_KindInteger]
  162. | _Permute2[_KindInt, _KindUInt]
  163. ),
  164. ) -> _DTypesInteger: ...
  165. @overload
  166. def dtypes(
  167. self,
  168. *,
  169. device: _DeviceLike = ...,
  170. kind: (
  171. _Permute1[_KindNumber]
  172. | _Permute3[_KindInteger, _KindFloat, _KindComplex]
  173. ),
  174. ) -> _DTypesNumber: ...
  175. @overload
  176. def dtypes(
  177. self,
  178. *,
  179. device: _DeviceLike = ...,
  180. kind: tuple[()],
  181. ) -> _EmptyDict: ...
  182. @overload
  183. def dtypes(
  184. self,
  185. *,
  186. device: _DeviceLike = ...,
  187. kind: tuple[_Kind, ...],
  188. ) -> _DTypesUnion: ...