_linalg.pyi 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. from collections.abc import Iterable
  2. from typing import (
  3. Literal as L,
  4. overload,
  5. TypeAlias,
  6. TypeVar,
  7. Any,
  8. SupportsIndex,
  9. SupportsInt,
  10. NamedTuple,
  11. )
  12. import numpy as np
  13. from numpy import (
  14. # re-exports
  15. vecdot,
  16. # other
  17. floating,
  18. complexfloating,
  19. signedinteger,
  20. unsignedinteger,
  21. timedelta64,
  22. object_,
  23. int32,
  24. float64,
  25. complex128,
  26. )
  27. from numpy.linalg import LinAlgError
  28. from numpy._core.fromnumeric import matrix_transpose
  29. from numpy._core.numeric import tensordot
  30. from numpy._typing import (
  31. NDArray,
  32. ArrayLike,
  33. DTypeLike,
  34. _ArrayLikeUnknown,
  35. _ArrayLikeBool_co,
  36. _ArrayLikeInt_co,
  37. _ArrayLikeUInt_co,
  38. _ArrayLikeFloat_co,
  39. _ArrayLikeComplex_co,
  40. _ArrayLikeTD64_co,
  41. _ArrayLikeObject_co,
  42. )
  43. __all__ = [
  44. "matrix_power",
  45. "solve",
  46. "tensorsolve",
  47. "tensorinv",
  48. "inv",
  49. "cholesky",
  50. "eigvals",
  51. "eigvalsh",
  52. "pinv",
  53. "slogdet",
  54. "det",
  55. "svd",
  56. "svdvals",
  57. "eig",
  58. "eigh",
  59. "lstsq",
  60. "norm",
  61. "qr",
  62. "cond",
  63. "matrix_rank",
  64. "LinAlgError",
  65. "multi_dot",
  66. "trace",
  67. "diagonal",
  68. "cross",
  69. "outer",
  70. "tensordot",
  71. "matmul",
  72. "matrix_transpose",
  73. "matrix_norm",
  74. "vector_norm",
  75. "vecdot",
  76. ]
  77. _ArrayType = TypeVar("_ArrayType", bound=NDArray[Any])
  78. _ModeKind: TypeAlias = L["reduced", "complete", "r", "raw"]
  79. ###
  80. fortran_int = np.intc
  81. class EigResult(NamedTuple):
  82. eigenvalues: NDArray[Any]
  83. eigenvectors: NDArray[Any]
  84. class EighResult(NamedTuple):
  85. eigenvalues: NDArray[Any]
  86. eigenvectors: NDArray[Any]
  87. class QRResult(NamedTuple):
  88. Q: NDArray[Any]
  89. R: NDArray[Any]
  90. class SlogdetResult(NamedTuple):
  91. # TODO: `sign` and `logabsdet` are scalars for input 2D arrays and
  92. # a `(x.ndim - 2)`` dimensionl arrays otherwise
  93. sign: Any
  94. logabsdet: Any
  95. class SVDResult(NamedTuple):
  96. U: NDArray[Any]
  97. S: NDArray[Any]
  98. Vh: NDArray[Any]
  99. @overload
  100. def tensorsolve(
  101. a: _ArrayLikeInt_co,
  102. b: _ArrayLikeInt_co,
  103. axes: None | Iterable[int] =...,
  104. ) -> NDArray[float64]: ...
  105. @overload
  106. def tensorsolve(
  107. a: _ArrayLikeFloat_co,
  108. b: _ArrayLikeFloat_co,
  109. axes: None | Iterable[int] =...,
  110. ) -> NDArray[floating[Any]]: ...
  111. @overload
  112. def tensorsolve(
  113. a: _ArrayLikeComplex_co,
  114. b: _ArrayLikeComplex_co,
  115. axes: None | Iterable[int] =...,
  116. ) -> NDArray[complexfloating[Any, Any]]: ...
  117. @overload
  118. def solve(
  119. a: _ArrayLikeInt_co,
  120. b: _ArrayLikeInt_co,
  121. ) -> NDArray[float64]: ...
  122. @overload
  123. def solve(
  124. a: _ArrayLikeFloat_co,
  125. b: _ArrayLikeFloat_co,
  126. ) -> NDArray[floating[Any]]: ...
  127. @overload
  128. def solve(
  129. a: _ArrayLikeComplex_co,
  130. b: _ArrayLikeComplex_co,
  131. ) -> NDArray[complexfloating[Any, Any]]: ...
  132. @overload
  133. def tensorinv(
  134. a: _ArrayLikeInt_co,
  135. ind: int = ...,
  136. ) -> NDArray[float64]: ...
  137. @overload
  138. def tensorinv(
  139. a: _ArrayLikeFloat_co,
  140. ind: int = ...,
  141. ) -> NDArray[floating[Any]]: ...
  142. @overload
  143. def tensorinv(
  144. a: _ArrayLikeComplex_co,
  145. ind: int = ...,
  146. ) -> NDArray[complexfloating[Any, Any]]: ...
  147. @overload
  148. def inv(a: _ArrayLikeInt_co) -> NDArray[float64]: ...
  149. @overload
  150. def inv(a: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ...
  151. @overload
  152. def inv(a: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
  153. # TODO: The supported input and output dtypes are dependent on the value of `n`.
  154. # For example: `n < 0` always casts integer types to float64
  155. def matrix_power(
  156. a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
  157. n: SupportsIndex,
  158. ) -> NDArray[Any]: ...
  159. @overload
  160. def cholesky(a: _ArrayLikeInt_co, /, *, upper: bool = False) -> NDArray[float64]: ...
  161. @overload
  162. def cholesky(a: _ArrayLikeFloat_co, /, *, upper: bool = False) -> NDArray[floating[Any]]: ...
  163. @overload
  164. def cholesky(a: _ArrayLikeComplex_co, /, *, upper: bool = False) -> NDArray[complexfloating[Any, Any]]: ...
  165. @overload
  166. def outer(x1: _ArrayLikeUnknown, x2: _ArrayLikeUnknown) -> NDArray[Any]: ...
  167. @overload
  168. def outer(x1: _ArrayLikeBool_co, x2: _ArrayLikeBool_co) -> NDArray[np.bool]: ...
  169. @overload
  170. def outer(x1: _ArrayLikeUInt_co, x2: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[Any]]: ...
  171. @overload
  172. def outer(x1: _ArrayLikeInt_co, x2: _ArrayLikeInt_co) -> NDArray[signedinteger[Any]]: ...
  173. @overload
  174. def outer(x1: _ArrayLikeFloat_co, x2: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ...
  175. @overload
  176. def outer(
  177. x1: _ArrayLikeComplex_co,
  178. x2: _ArrayLikeComplex_co,
  179. ) -> NDArray[complexfloating[Any, Any]]: ...
  180. @overload
  181. def outer(
  182. x1: _ArrayLikeTD64_co,
  183. x2: _ArrayLikeTD64_co,
  184. out: None = ...,
  185. ) -> NDArray[timedelta64]: ...
  186. @overload
  187. def outer(x1: _ArrayLikeObject_co, x2: _ArrayLikeObject_co) -> NDArray[object_]: ...
  188. @overload
  189. def outer(
  190. x1: _ArrayLikeComplex_co | _ArrayLikeTD64_co | _ArrayLikeObject_co,
  191. x2: _ArrayLikeComplex_co | _ArrayLikeTD64_co | _ArrayLikeObject_co,
  192. ) -> _ArrayType: ...
  193. @overload
  194. def qr(a: _ArrayLikeInt_co, mode: _ModeKind = ...) -> QRResult: ...
  195. @overload
  196. def qr(a: _ArrayLikeFloat_co, mode: _ModeKind = ...) -> QRResult: ...
  197. @overload
  198. def qr(a: _ArrayLikeComplex_co, mode: _ModeKind = ...) -> QRResult: ...
  199. @overload
  200. def eigvals(a: _ArrayLikeInt_co) -> NDArray[float64] | NDArray[complex128]: ...
  201. @overload
  202. def eigvals(a: _ArrayLikeFloat_co) -> NDArray[floating[Any]] | NDArray[complexfloating[Any, Any]]: ...
  203. @overload
  204. def eigvals(a: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
  205. @overload
  206. def eigvalsh(a: _ArrayLikeInt_co, UPLO: L["L", "U", "l", "u"] = ...) -> NDArray[float64]: ...
  207. @overload
  208. def eigvalsh(a: _ArrayLikeComplex_co, UPLO: L["L", "U", "l", "u"] = ...) -> NDArray[floating[Any]]: ...
  209. @overload
  210. def eig(a: _ArrayLikeInt_co) -> EigResult: ...
  211. @overload
  212. def eig(a: _ArrayLikeFloat_co) -> EigResult: ...
  213. @overload
  214. def eig(a: _ArrayLikeComplex_co) -> EigResult: ...
  215. @overload
  216. def eigh(
  217. a: _ArrayLikeInt_co,
  218. UPLO: L["L", "U", "l", "u"] = ...,
  219. ) -> EighResult: ...
  220. @overload
  221. def eigh(
  222. a: _ArrayLikeFloat_co,
  223. UPLO: L["L", "U", "l", "u"] = ...,
  224. ) -> EighResult: ...
  225. @overload
  226. def eigh(
  227. a: _ArrayLikeComplex_co,
  228. UPLO: L["L", "U", "l", "u"] = ...,
  229. ) -> EighResult: ...
  230. @overload
  231. def svd(
  232. a: _ArrayLikeInt_co,
  233. full_matrices: bool = ...,
  234. compute_uv: L[True] = ...,
  235. hermitian: bool = ...,
  236. ) -> SVDResult: ...
  237. @overload
  238. def svd(
  239. a: _ArrayLikeFloat_co,
  240. full_matrices: bool = ...,
  241. compute_uv: L[True] = ...,
  242. hermitian: bool = ...,
  243. ) -> SVDResult: ...
  244. @overload
  245. def svd(
  246. a: _ArrayLikeComplex_co,
  247. full_matrices: bool = ...,
  248. compute_uv: L[True] = ...,
  249. hermitian: bool = ...,
  250. ) -> SVDResult: ...
  251. @overload
  252. def svd(
  253. a: _ArrayLikeInt_co,
  254. full_matrices: bool = ...,
  255. compute_uv: L[False] = ...,
  256. hermitian: bool = ...,
  257. ) -> NDArray[float64]: ...
  258. @overload
  259. def svd(
  260. a: _ArrayLikeComplex_co,
  261. full_matrices: bool = ...,
  262. compute_uv: L[False] = ...,
  263. hermitian: bool = ...,
  264. ) -> NDArray[floating[Any]]: ...
  265. def svdvals(
  266. x: _ArrayLikeInt_co | _ArrayLikeFloat_co | _ArrayLikeComplex_co
  267. ) -> NDArray[floating[Any]]: ...
  268. # TODO: Returns a scalar for 2D arrays and
  269. # a `(x.ndim - 2)`` dimensionl array otherwise
  270. def cond(x: _ArrayLikeComplex_co, p: None | float | L["fro", "nuc"] = ...) -> Any: ...
  271. # TODO: Returns `int` for <2D arrays and `intp` otherwise
  272. def matrix_rank(
  273. A: _ArrayLikeComplex_co,
  274. tol: None | _ArrayLikeFloat_co = ...,
  275. hermitian: bool = ...,
  276. *,
  277. rtol: None | _ArrayLikeFloat_co = ...,
  278. ) -> Any: ...
  279. @overload
  280. def pinv(
  281. a: _ArrayLikeInt_co,
  282. rcond: _ArrayLikeFloat_co = ...,
  283. hermitian: bool = ...,
  284. ) -> NDArray[float64]: ...
  285. @overload
  286. def pinv(
  287. a: _ArrayLikeFloat_co,
  288. rcond: _ArrayLikeFloat_co = ...,
  289. hermitian: bool = ...,
  290. ) -> NDArray[floating[Any]]: ...
  291. @overload
  292. def pinv(
  293. a: _ArrayLikeComplex_co,
  294. rcond: _ArrayLikeFloat_co = ...,
  295. hermitian: bool = ...,
  296. ) -> NDArray[complexfloating[Any, Any]]: ...
  297. # TODO: Returns a 2-tuple of scalars for 2D arrays and
  298. # a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise
  299. def slogdet(a: _ArrayLikeComplex_co) -> SlogdetResult: ...
  300. # TODO: Returns a 2-tuple of scalars for 2D arrays and
  301. # a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise
  302. def det(a: _ArrayLikeComplex_co) -> Any: ...
  303. @overload
  304. def lstsq(a: _ArrayLikeInt_co, b: _ArrayLikeInt_co, rcond: None | float = ...) -> tuple[
  305. NDArray[float64],
  306. NDArray[float64],
  307. int32,
  308. NDArray[float64],
  309. ]: ...
  310. @overload
  311. def lstsq(a: _ArrayLikeFloat_co, b: _ArrayLikeFloat_co, rcond: None | float = ...) -> tuple[
  312. NDArray[floating[Any]],
  313. NDArray[floating[Any]],
  314. int32,
  315. NDArray[floating[Any]],
  316. ]: ...
  317. @overload
  318. def lstsq(a: _ArrayLikeComplex_co, b: _ArrayLikeComplex_co, rcond: None | float = ...) -> tuple[
  319. NDArray[complexfloating[Any, Any]],
  320. NDArray[floating[Any]],
  321. int32,
  322. NDArray[floating[Any]],
  323. ]: ...
  324. @overload
  325. def norm(
  326. x: ArrayLike,
  327. ord: None | float | L["fro", "nuc"] = ...,
  328. axis: None = ...,
  329. keepdims: bool = ...,
  330. ) -> floating[Any]: ...
  331. @overload
  332. def norm(
  333. x: ArrayLike,
  334. ord: None | float | L["fro", "nuc"] = ...,
  335. axis: SupportsInt | SupportsIndex | tuple[int, ...] = ...,
  336. keepdims: bool = ...,
  337. ) -> Any: ...
  338. @overload
  339. def matrix_norm(
  340. x: ArrayLike,
  341. /,
  342. *,
  343. ord: None | float | L["fro", "nuc"] = ...,
  344. keepdims: bool = ...,
  345. ) -> floating[Any]: ...
  346. @overload
  347. def matrix_norm(
  348. x: ArrayLike,
  349. /,
  350. *,
  351. ord: None | float | L["fro", "nuc"] = ...,
  352. keepdims: bool = ...,
  353. ) -> Any: ...
  354. @overload
  355. def vector_norm(
  356. x: ArrayLike,
  357. /,
  358. *,
  359. axis: None = ...,
  360. ord: None | float = ...,
  361. keepdims: bool = ...,
  362. ) -> floating[Any]: ...
  363. @overload
  364. def vector_norm(
  365. x: ArrayLike,
  366. /,
  367. *,
  368. axis: SupportsInt | SupportsIndex | tuple[int, ...] = ...,
  369. ord: None | float = ...,
  370. keepdims: bool = ...,
  371. ) -> Any: ...
  372. # TODO: Returns a scalar or array
  373. def multi_dot(
  374. arrays: Iterable[_ArrayLikeComplex_co | _ArrayLikeObject_co | _ArrayLikeTD64_co],
  375. *,
  376. out: None | NDArray[Any] = ...,
  377. ) -> Any: ...
  378. def diagonal(
  379. x: ArrayLike, # >= 2D array
  380. /,
  381. *,
  382. offset: SupportsIndex = ...,
  383. ) -> NDArray[Any]: ...
  384. def trace(
  385. x: ArrayLike, # >= 2D array
  386. /,
  387. *,
  388. offset: SupportsIndex = ...,
  389. dtype: DTypeLike = ...,
  390. ) -> Any: ...
  391. @overload
  392. def cross(
  393. x1: _ArrayLikeUInt_co,
  394. x2: _ArrayLikeUInt_co,
  395. /,
  396. *,
  397. axis: int = ...,
  398. ) -> NDArray[unsignedinteger[Any]]: ...
  399. @overload
  400. def cross(
  401. x1: _ArrayLikeInt_co,
  402. x2: _ArrayLikeInt_co,
  403. /,
  404. *,
  405. axis: int = ...,
  406. ) -> NDArray[signedinteger[Any]]: ...
  407. @overload
  408. def cross(
  409. x1: _ArrayLikeFloat_co,
  410. x2: _ArrayLikeFloat_co,
  411. /,
  412. *,
  413. axis: int = ...,
  414. ) -> NDArray[floating[Any]]: ...
  415. @overload
  416. def cross(
  417. x1: _ArrayLikeComplex_co,
  418. x2: _ArrayLikeComplex_co,
  419. /,
  420. *,
  421. axis: int = ...,
  422. ) -> NDArray[complexfloating[Any, Any]]: ...
  423. @overload
  424. def matmul(
  425. x1: _ArrayLikeInt_co,
  426. x2: _ArrayLikeInt_co,
  427. ) -> NDArray[signedinteger[Any]]: ...
  428. @overload
  429. def matmul(
  430. x1: _ArrayLikeUInt_co,
  431. x2: _ArrayLikeUInt_co,
  432. ) -> NDArray[unsignedinteger[Any]]: ...
  433. @overload
  434. def matmul(
  435. x1: _ArrayLikeFloat_co,
  436. x2: _ArrayLikeFloat_co,
  437. ) -> NDArray[floating[Any]]: ...
  438. @overload
  439. def matmul(
  440. x1: _ArrayLikeComplex_co,
  441. x2: _ArrayLikeComplex_co,
  442. ) -> NDArray[complexfloating[Any, Any]]: ...