_array_api_info.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. """
  2. Array API Inspection namespace
  3. This is the namespace for inspection functions as defined by the array API
  4. standard. See
  5. https://data-apis.org/array-api/latest/API_specification/inspection.html for
  6. more details.
  7. """
  8. from numpy._core import (
  9. bool,
  10. complex64,
  11. complex128,
  12. dtype,
  13. float32,
  14. float64,
  15. int8,
  16. int16,
  17. int32,
  18. int64,
  19. intp,
  20. uint8,
  21. uint16,
  22. uint32,
  23. uint64,
  24. )
  25. from numpy._utils import set_module
  26. @set_module('numpy')
  27. class __array_namespace_info__:
  28. """
  29. Get the array API inspection namespace for NumPy.
  30. The array API inspection namespace defines the following functions:
  31. - capabilities()
  32. - default_device()
  33. - default_dtypes()
  34. - dtypes()
  35. - devices()
  36. See
  37. https://data-apis.org/array-api/latest/API_specification/inspection.html
  38. for more details.
  39. Returns
  40. -------
  41. info : ModuleType
  42. The array API inspection namespace for NumPy.
  43. Examples
  44. --------
  45. >>> info = np.__array_namespace_info__()
  46. >>> info.default_dtypes()
  47. {'real floating': numpy.float64,
  48. 'complex floating': numpy.complex128,
  49. 'integral': numpy.int64,
  50. 'indexing': numpy.int64}
  51. """
  52. def capabilities(self):
  53. """
  54. Return a dictionary of array API library capabilities.
  55. The resulting dictionary has the following keys:
  56. - **"boolean indexing"**: boolean indicating whether an array library
  57. supports boolean indexing. Always ``True`` for NumPy.
  58. - **"data-dependent shapes"**: boolean indicating whether an array
  59. library supports data-dependent output shapes. Always ``True`` for
  60. NumPy.
  61. See
  62. https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
  63. for more details.
  64. See Also
  65. --------
  66. __array_namespace_info__.default_device,
  67. __array_namespace_info__.default_dtypes,
  68. __array_namespace_info__.dtypes,
  69. __array_namespace_info__.devices
  70. Returns
  71. -------
  72. capabilities : dict
  73. A dictionary of array API library capabilities.
  74. Examples
  75. --------
  76. >>> info = np.__array_namespace_info__()
  77. >>> info.capabilities()
  78. {'boolean indexing': True,
  79. 'data-dependent shapes': True,
  80. 'max dimensions': 64}
  81. """
  82. return {
  83. "boolean indexing": True,
  84. "data-dependent shapes": True,
  85. "max dimensions": 64,
  86. }
  87. def default_device(self):
  88. """
  89. The default device used for new NumPy arrays.
  90. For NumPy, this always returns ``'cpu'``.
  91. See Also
  92. --------
  93. __array_namespace_info__.capabilities,
  94. __array_namespace_info__.default_dtypes,
  95. __array_namespace_info__.dtypes,
  96. __array_namespace_info__.devices
  97. Returns
  98. -------
  99. device : str
  100. The default device used for new NumPy arrays.
  101. Examples
  102. --------
  103. >>> info = np.__array_namespace_info__()
  104. >>> info.default_device()
  105. 'cpu'
  106. """
  107. return "cpu"
  108. def default_dtypes(self, *, device=None):
  109. """
  110. The default data types used for new NumPy arrays.
  111. For NumPy, this always returns the following dictionary:
  112. - **"real floating"**: ``numpy.float64``
  113. - **"complex floating"**: ``numpy.complex128``
  114. - **"integral"**: ``numpy.intp``
  115. - **"indexing"**: ``numpy.intp``
  116. Parameters
  117. ----------
  118. device : str, optional
  119. The device to get the default data types for. For NumPy, only
  120. ``'cpu'`` is allowed.
  121. Returns
  122. -------
  123. dtypes : dict
  124. A dictionary describing the default data types used for new NumPy
  125. arrays.
  126. See Also
  127. --------
  128. __array_namespace_info__.capabilities,
  129. __array_namespace_info__.default_device,
  130. __array_namespace_info__.dtypes,
  131. __array_namespace_info__.devices
  132. Examples
  133. --------
  134. >>> info = np.__array_namespace_info__()
  135. >>> info.default_dtypes()
  136. {'real floating': numpy.float64,
  137. 'complex floating': numpy.complex128,
  138. 'integral': numpy.int64,
  139. 'indexing': numpy.int64}
  140. """
  141. if device not in ["cpu", None]:
  142. raise ValueError(
  143. 'Device not understood. Only "cpu" is allowed, but received:'
  144. f' {device}'
  145. )
  146. return {
  147. "real floating": dtype(float64),
  148. "complex floating": dtype(complex128),
  149. "integral": dtype(intp),
  150. "indexing": dtype(intp),
  151. }
  152. def dtypes(self, *, device=None, kind=None):
  153. """
  154. The array API data types supported by NumPy.
  155. Note that this function only returns data types that are defined by
  156. the array API.
  157. Parameters
  158. ----------
  159. device : str, optional
  160. The device to get the data types for. For NumPy, only ``'cpu'`` is
  161. allowed.
  162. kind : str or tuple of str, optional
  163. The kind of data types to return. If ``None``, all data types are
  164. returned. If a string, only data types of that kind are returned.
  165. If a tuple, a dictionary containing the union of the given kinds
  166. is returned. The following kinds are supported:
  167. - ``'bool'``: boolean data types (i.e., ``bool``).
  168. - ``'signed integer'``: signed integer data types (i.e., ``int8``,
  169. ``int16``, ``int32``, ``int64``).
  170. - ``'unsigned integer'``: unsigned integer data types (i.e.,
  171. ``uint8``, ``uint16``, ``uint32``, ``uint64``).
  172. - ``'integral'``: integer data types. Shorthand for ``('signed
  173. integer', 'unsigned integer')``.
  174. - ``'real floating'``: real-valued floating-point data types
  175. (i.e., ``float32``, ``float64``).
  176. - ``'complex floating'``: complex floating-point data types (i.e.,
  177. ``complex64``, ``complex128``).
  178. - ``'numeric'``: numeric data types. Shorthand for ``('integral',
  179. 'real floating', 'complex floating')``.
  180. Returns
  181. -------
  182. dtypes : dict
  183. A dictionary mapping the names of data types to the corresponding
  184. NumPy data types.
  185. See Also
  186. --------
  187. __array_namespace_info__.capabilities,
  188. __array_namespace_info__.default_device,
  189. __array_namespace_info__.default_dtypes,
  190. __array_namespace_info__.devices
  191. Examples
  192. --------
  193. >>> info = np.__array_namespace_info__()
  194. >>> info.dtypes(kind='signed integer')
  195. {'int8': numpy.int8,
  196. 'int16': numpy.int16,
  197. 'int32': numpy.int32,
  198. 'int64': numpy.int64}
  199. """
  200. if device not in ["cpu", None]:
  201. raise ValueError(
  202. 'Device not understood. Only "cpu" is allowed, but received:'
  203. f' {device}'
  204. )
  205. if kind is None:
  206. return {
  207. "bool": dtype(bool),
  208. "int8": dtype(int8),
  209. "int16": dtype(int16),
  210. "int32": dtype(int32),
  211. "int64": dtype(int64),
  212. "uint8": dtype(uint8),
  213. "uint16": dtype(uint16),
  214. "uint32": dtype(uint32),
  215. "uint64": dtype(uint64),
  216. "float32": dtype(float32),
  217. "float64": dtype(float64),
  218. "complex64": dtype(complex64),
  219. "complex128": dtype(complex128),
  220. }
  221. if kind == "bool":
  222. return {"bool": bool}
  223. if kind == "signed integer":
  224. return {
  225. "int8": dtype(int8),
  226. "int16": dtype(int16),
  227. "int32": dtype(int32),
  228. "int64": dtype(int64),
  229. }
  230. if kind == "unsigned integer":
  231. return {
  232. "uint8": dtype(uint8),
  233. "uint16": dtype(uint16),
  234. "uint32": dtype(uint32),
  235. "uint64": dtype(uint64),
  236. }
  237. if kind == "integral":
  238. return {
  239. "int8": dtype(int8),
  240. "int16": dtype(int16),
  241. "int32": dtype(int32),
  242. "int64": dtype(int64),
  243. "uint8": dtype(uint8),
  244. "uint16": dtype(uint16),
  245. "uint32": dtype(uint32),
  246. "uint64": dtype(uint64),
  247. }
  248. if kind == "real floating":
  249. return {
  250. "float32": dtype(float32),
  251. "float64": dtype(float64),
  252. }
  253. if kind == "complex floating":
  254. return {
  255. "complex64": dtype(complex64),
  256. "complex128": dtype(complex128),
  257. }
  258. if kind == "numeric":
  259. return {
  260. "int8": dtype(int8),
  261. "int16": dtype(int16),
  262. "int32": dtype(int32),
  263. "int64": dtype(int64),
  264. "uint8": dtype(uint8),
  265. "uint16": dtype(uint16),
  266. "uint32": dtype(uint32),
  267. "uint64": dtype(uint64),
  268. "float32": dtype(float32),
  269. "float64": dtype(float64),
  270. "complex64": dtype(complex64),
  271. "complex128": dtype(complex128),
  272. }
  273. if isinstance(kind, tuple):
  274. res = {}
  275. for k in kind:
  276. res.update(self.dtypes(kind=k))
  277. return res
  278. raise ValueError(f"unsupported kind: {kind!r}")
  279. def devices(self):
  280. """
  281. The devices supported by NumPy.
  282. For NumPy, this always returns ``['cpu']``.
  283. Returns
  284. -------
  285. devices : list of str
  286. The devices supported by NumPy.
  287. See Also
  288. --------
  289. __array_namespace_info__.capabilities,
  290. __array_namespace_info__.default_device,
  291. __array_namespace_info__.default_dtypes,
  292. __array_namespace_info__.dtypes
  293. Examples
  294. --------
  295. >>> info = np.__array_namespace_info__()
  296. >>> info.devices()
  297. ['cpu']
  298. """
  299. return ["cpu"]