_linalg.pyi 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. from collections.abc import Iterable
  2. from typing import (
  3. Any,
  4. Literal as L,
  5. NamedTuple,
  6. Never,
  7. SupportsIndex,
  8. SupportsInt,
  9. TypeAlias,
  10. TypeVar,
  11. overload,
  12. )
  13. import numpy as np
  14. from numpy import (
  15. complex128,
  16. complexfloating,
  17. float64,
  18. floating,
  19. int32,
  20. object_,
  21. signedinteger,
  22. timedelta64,
  23. unsignedinteger,
  24. vecdot,
  25. )
  26. from numpy._core.fromnumeric import matrix_transpose
  27. from numpy._globals import _NoValueType
  28. from numpy._typing import (
  29. ArrayLike,
  30. DTypeLike,
  31. NDArray,
  32. _ArrayLike,
  33. _ArrayLikeBool_co,
  34. _ArrayLikeComplex_co,
  35. _ArrayLikeFloat_co,
  36. _ArrayLikeInt_co,
  37. _ArrayLikeNumber_co,
  38. _ArrayLikeObject_co,
  39. _ArrayLikeTD64_co,
  40. _ArrayLikeUInt_co,
  41. _NestedSequence,
  42. _ShapeLike,
  43. )
  44. from numpy.linalg import LinAlgError
  45. __all__ = [
  46. "matrix_power",
  47. "solve",
  48. "tensorsolve",
  49. "tensorinv",
  50. "inv",
  51. "cholesky",
  52. "eigvals",
  53. "eigvalsh",
  54. "pinv",
  55. "slogdet",
  56. "det",
  57. "svd",
  58. "svdvals",
  59. "eig",
  60. "eigh",
  61. "lstsq",
  62. "norm",
  63. "qr",
  64. "cond",
  65. "matrix_rank",
  66. "LinAlgError",
  67. "multi_dot",
  68. "trace",
  69. "diagonal",
  70. "cross",
  71. "outer",
  72. "tensordot",
  73. "matmul",
  74. "matrix_transpose",
  75. "matrix_norm",
  76. "vector_norm",
  77. "vecdot",
  78. ]
  79. _NumberT = TypeVar("_NumberT", bound=np.number)
  80. _NumericScalarT = TypeVar("_NumericScalarT", bound=np.number | np.timedelta64 | np.object_)
  81. _ModeKind: TypeAlias = L["reduced", "complete", "r", "raw"]
  82. ###
  83. fortran_int = np.intc
  84. class EigResult(NamedTuple):
  85. eigenvalues: NDArray[Any]
  86. eigenvectors: NDArray[Any]
  87. class EighResult(NamedTuple):
  88. eigenvalues: NDArray[Any]
  89. eigenvectors: NDArray[Any]
  90. class QRResult(NamedTuple):
  91. Q: NDArray[Any]
  92. R: NDArray[Any]
  93. class SlogdetResult(NamedTuple):
  94. # TODO: `sign` and `logabsdet` are scalars for input 2D arrays and
  95. # a `(x.ndim - 2)`` dimensionl arrays otherwise
  96. sign: Any
  97. logabsdet: Any
  98. class SVDResult(NamedTuple):
  99. U: NDArray[Any]
  100. S: NDArray[Any]
  101. Vh: NDArray[Any]
  102. @overload
  103. def tensorsolve(
  104. a: _ArrayLikeInt_co,
  105. b: _ArrayLikeInt_co,
  106. axes: Iterable[int] | None = None,
  107. ) -> NDArray[float64]: ...
  108. @overload
  109. def tensorsolve(
  110. a: _ArrayLikeFloat_co,
  111. b: _ArrayLikeFloat_co,
  112. axes: Iterable[int] | None = None,
  113. ) -> NDArray[floating]: ...
  114. @overload
  115. def tensorsolve(
  116. a: _ArrayLikeComplex_co,
  117. b: _ArrayLikeComplex_co,
  118. axes: Iterable[int] | None = None,
  119. ) -> NDArray[complexfloating]: ...
  120. @overload
  121. def solve(
  122. a: _ArrayLikeInt_co,
  123. b: _ArrayLikeInt_co,
  124. ) -> NDArray[float64]: ...
  125. @overload
  126. def solve(
  127. a: _ArrayLikeFloat_co,
  128. b: _ArrayLikeFloat_co,
  129. ) -> NDArray[floating]: ...
  130. @overload
  131. def solve(
  132. a: _ArrayLikeComplex_co,
  133. b: _ArrayLikeComplex_co,
  134. ) -> NDArray[complexfloating]: ...
  135. @overload
  136. def tensorinv(
  137. a: _ArrayLikeInt_co,
  138. ind: int = 2,
  139. ) -> NDArray[float64]: ...
  140. @overload
  141. def tensorinv(
  142. a: _ArrayLikeFloat_co,
  143. ind: int = 2,
  144. ) -> NDArray[floating]: ...
  145. @overload
  146. def tensorinv(
  147. a: _ArrayLikeComplex_co,
  148. ind: int = 2,
  149. ) -> NDArray[complexfloating]: ...
  150. @overload
  151. def inv(a: _ArrayLikeInt_co) -> NDArray[float64]: ...
  152. @overload
  153. def inv(a: _ArrayLikeFloat_co) -> NDArray[floating]: ...
  154. @overload
  155. def inv(a: _ArrayLikeComplex_co) -> NDArray[complexfloating]: ...
  156. # TODO: The supported input and output dtypes are dependent on the value of `n`.
  157. # For example: `n < 0` always casts integer types to float64
  158. def matrix_power(
  159. a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
  160. n: SupportsIndex,
  161. ) -> NDArray[Any]: ...
  162. @overload
  163. def cholesky(a: _ArrayLikeInt_co, /, *, upper: bool = False) -> NDArray[float64]: ...
  164. @overload
  165. def cholesky(a: _ArrayLikeFloat_co, /, *, upper: bool = False) -> NDArray[floating]: ...
  166. @overload
  167. def cholesky(a: _ArrayLikeComplex_co, /, *, upper: bool = False) -> NDArray[complexfloating]: ...
  168. @overload
  169. def outer(x1: _ArrayLike[Never], x2: _ArrayLike[Never], /) -> NDArray[Any]: ...
  170. @overload
  171. def outer(x1: _ArrayLikeBool_co, x2: _ArrayLikeBool_co, /) -> NDArray[np.bool]: ...
  172. @overload
  173. def outer(x1: _ArrayLike[_NumberT], x2: _ArrayLike[_NumberT], /) -> NDArray[_NumberT]: ...
  174. @overload
  175. def outer(x1: _ArrayLikeUInt_co, x2: _ArrayLikeUInt_co, /) -> NDArray[unsignedinteger]: ...
  176. @overload
  177. def outer(x1: _ArrayLikeInt_co, x2: _ArrayLikeInt_co, /) -> NDArray[signedinteger]: ...
  178. @overload
  179. def outer(x1: _ArrayLikeFloat_co, x2: _ArrayLikeFloat_co, /) -> NDArray[floating]: ...
  180. @overload
  181. def outer(x1: _ArrayLikeComplex_co, x2: _ArrayLikeComplex_co, /) -> NDArray[complexfloating]: ...
  182. @overload
  183. def outer(x1: _ArrayLikeTD64_co, x2: _ArrayLikeTD64_co, /) -> NDArray[timedelta64]: ...
  184. @overload
  185. def outer(x1: _ArrayLikeObject_co, x2: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
  186. @overload
  187. def outer(
  188. x1: _ArrayLikeComplex_co | _ArrayLikeTD64_co | _ArrayLikeObject_co,
  189. x2: _ArrayLikeComplex_co | _ArrayLikeTD64_co | _ArrayLikeObject_co,
  190. /,
  191. ) -> NDArray[Any]: ...
  192. @overload
  193. def qr(a: _ArrayLikeInt_co, mode: _ModeKind = "reduced") -> QRResult: ...
  194. @overload
  195. def qr(a: _ArrayLikeFloat_co, mode: _ModeKind = "reduced") -> QRResult: ...
  196. @overload
  197. def qr(a: _ArrayLikeComplex_co, mode: _ModeKind = "reduced") -> QRResult: ...
  198. @overload
  199. def eigvals(a: _ArrayLikeInt_co) -> NDArray[float64] | NDArray[complex128]: ...
  200. @overload
  201. def eigvals(a: _ArrayLikeFloat_co) -> NDArray[floating] | NDArray[complexfloating]: ...
  202. @overload
  203. def eigvals(a: _ArrayLikeComplex_co) -> NDArray[complexfloating]: ...
  204. @overload
  205. def eigvalsh(a: _ArrayLikeInt_co, UPLO: L["L", "U", "l", "u"] = "L") -> NDArray[float64]: ...
  206. @overload
  207. def eigvalsh(a: _ArrayLikeComplex_co, UPLO: L["L", "U", "l", "u"] = "L") -> NDArray[floating]: ...
  208. @overload
  209. def eig(a: _ArrayLikeInt_co) -> EigResult: ...
  210. @overload
  211. def eig(a: _ArrayLikeFloat_co) -> EigResult: ...
  212. @overload
  213. def eig(a: _ArrayLikeComplex_co) -> EigResult: ...
  214. @overload
  215. def eigh(
  216. a: _ArrayLikeInt_co,
  217. UPLO: L["L", "U", "l", "u"] = "L",
  218. ) -> EighResult: ...
  219. @overload
  220. def eigh(
  221. a: _ArrayLikeFloat_co,
  222. UPLO: L["L", "U", "l", "u"] = "L",
  223. ) -> EighResult: ...
  224. @overload
  225. def eigh(
  226. a: _ArrayLikeComplex_co,
  227. UPLO: L["L", "U", "l", "u"] = "L",
  228. ) -> EighResult: ...
  229. @overload
  230. def svd(
  231. a: _ArrayLikeInt_co,
  232. full_matrices: bool = True,
  233. compute_uv: L[True] = True,
  234. hermitian: bool = False,
  235. ) -> SVDResult: ...
  236. @overload
  237. def svd(
  238. a: _ArrayLikeFloat_co,
  239. full_matrices: bool = True,
  240. compute_uv: L[True] = True,
  241. hermitian: bool = False,
  242. ) -> SVDResult: ...
  243. @overload
  244. def svd(
  245. a: _ArrayLikeComplex_co,
  246. full_matrices: bool = True,
  247. compute_uv: L[True] = True,
  248. hermitian: bool = False,
  249. ) -> SVDResult: ...
  250. @overload
  251. def svd(
  252. a: _ArrayLikeInt_co,
  253. full_matrices: bool = True,
  254. *,
  255. compute_uv: L[False],
  256. hermitian: bool = False,
  257. ) -> NDArray[float64]: ...
  258. @overload
  259. def svd(
  260. a: _ArrayLikeInt_co,
  261. full_matrices: bool,
  262. compute_uv: L[False],
  263. hermitian: bool = False,
  264. ) -> NDArray[float64]: ...
  265. @overload
  266. def svd(
  267. a: _ArrayLikeComplex_co,
  268. full_matrices: bool = True,
  269. *,
  270. compute_uv: L[False],
  271. hermitian: bool = False,
  272. ) -> NDArray[floating]: ...
  273. @overload
  274. def svd(
  275. a: _ArrayLikeComplex_co,
  276. full_matrices: bool,
  277. compute_uv: L[False],
  278. hermitian: bool = False,
  279. ) -> NDArray[floating]: ...
  280. # the ignored `overload-overlap` mypy error below is a false-positive
  281. @overload
  282. def svdvals( # type: ignore[overload-overlap]
  283. x: _ArrayLike[np.float64 | np.complex128 | np.integer | np.bool] | _NestedSequence[complex], /
  284. ) -> NDArray[np.float64]: ...
  285. @overload
  286. def svdvals(x: _ArrayLike[np.float32 | np.complex64], /) -> NDArray[np.float32]: ...
  287. @overload
  288. def svdvals(x: _ArrayLikeNumber_co, /) -> NDArray[floating]: ...
  289. # TODO: Returns a scalar for 2D arrays and
  290. # a `(x.ndim - 2)`` dimensionl array otherwise
  291. def cond(x: _ArrayLikeComplex_co, p: float | L["fro", "nuc"] | None = None) -> Any: ...
  292. # TODO: Returns `int` for <2D arrays and `intp` otherwise
  293. def matrix_rank(
  294. A: _ArrayLikeComplex_co,
  295. tol: _ArrayLikeFloat_co | None = None,
  296. hermitian: bool = False,
  297. *,
  298. rtol: _ArrayLikeFloat_co | None = None,
  299. ) -> Any: ...
  300. @overload
  301. def pinv(
  302. a: _ArrayLikeInt_co,
  303. rcond: _ArrayLikeFloat_co | None = None,
  304. hermitian: bool = False,
  305. *,
  306. rtol: _ArrayLikeFloat_co | _NoValueType = ...,
  307. ) -> NDArray[float64]: ...
  308. @overload
  309. def pinv(
  310. a: _ArrayLikeFloat_co,
  311. rcond: _ArrayLikeFloat_co | None = None,
  312. hermitian: bool = False,
  313. *,
  314. rtol: _ArrayLikeFloat_co | _NoValueType = ...,
  315. ) -> NDArray[floating]: ...
  316. @overload
  317. def pinv(
  318. a: _ArrayLikeComplex_co,
  319. rcond: _ArrayLikeFloat_co | None = None,
  320. hermitian: bool = False,
  321. *,
  322. rtol: _ArrayLikeFloat_co | _NoValueType = ...,
  323. ) -> NDArray[complexfloating]: ...
  324. # TODO: Returns a 2-tuple of scalars for 2D arrays and
  325. # a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise
  326. def slogdet(a: _ArrayLikeComplex_co) -> SlogdetResult: ...
  327. # TODO: Returns a 2-tuple of scalars for 2D arrays and
  328. # a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise
  329. def det(a: _ArrayLikeComplex_co) -> Any: ...
  330. @overload
  331. def lstsq(a: _ArrayLikeInt_co, b: _ArrayLikeInt_co, rcond: float | None = None) -> tuple[
  332. NDArray[float64],
  333. NDArray[float64],
  334. int32,
  335. NDArray[float64],
  336. ]: ...
  337. @overload
  338. def lstsq(a: _ArrayLikeFloat_co, b: _ArrayLikeFloat_co, rcond: float | None = None) -> tuple[
  339. NDArray[floating],
  340. NDArray[floating],
  341. int32,
  342. NDArray[floating],
  343. ]: ...
  344. @overload
  345. def lstsq(a: _ArrayLikeComplex_co, b: _ArrayLikeComplex_co, rcond: float | None = None) -> tuple[
  346. NDArray[complexfloating],
  347. NDArray[floating],
  348. int32,
  349. NDArray[floating],
  350. ]: ...
  351. @overload
  352. def norm(
  353. x: ArrayLike,
  354. ord: float | L["fro", "nuc"] | None = None,
  355. axis: None = None,
  356. keepdims: L[False] = False,
  357. ) -> floating: ...
  358. @overload
  359. def norm(
  360. x: ArrayLike,
  361. ord: float | L["fro", "nuc"] | None,
  362. axis: SupportsInt | SupportsIndex | tuple[int, ...] | None,
  363. keepdims: bool = False,
  364. ) -> Any: ...
  365. @overload
  366. def norm(
  367. x: ArrayLike,
  368. ord: float | L["fro", "nuc"] | None = None,
  369. *,
  370. axis: SupportsInt | SupportsIndex | tuple[int, ...] | None,
  371. keepdims: bool = False,
  372. ) -> Any: ...
  373. @overload
  374. def matrix_norm(
  375. x: ArrayLike,
  376. /,
  377. *,
  378. ord: float | L["fro", "nuc"] | None = "fro",
  379. keepdims: L[False] = False,
  380. ) -> floating: ...
  381. @overload
  382. def matrix_norm(
  383. x: ArrayLike,
  384. /,
  385. *,
  386. ord: float | L["fro", "nuc"] | None = "fro",
  387. keepdims: bool = False,
  388. ) -> Any: ...
  389. @overload
  390. def vector_norm(
  391. x: ArrayLike,
  392. /,
  393. *,
  394. axis: None = None,
  395. ord: float | None = 2,
  396. keepdims: L[False] = False,
  397. ) -> floating: ...
  398. @overload
  399. def vector_norm(
  400. x: ArrayLike,
  401. /,
  402. *,
  403. axis: SupportsInt | SupportsIndex | tuple[int, ...],
  404. ord: float | None = 2,
  405. keepdims: bool = False,
  406. ) -> Any: ...
  407. # keep in sync with numpy._core.numeric.tensordot (ignoring `/, *`)
  408. @overload
  409. def tensordot(
  410. a: _ArrayLike[_NumericScalarT],
  411. b: _ArrayLike[_NumericScalarT],
  412. /,
  413. *,
  414. axes: int | tuple[_ShapeLike, _ShapeLike] = 2,
  415. ) -> NDArray[_NumericScalarT]: ...
  416. @overload
  417. def tensordot(
  418. a: _ArrayLikeBool_co,
  419. b: _ArrayLikeBool_co,
  420. /,
  421. *,
  422. axes: int | tuple[_ShapeLike, _ShapeLike] = 2,
  423. ) -> NDArray[np.bool_]: ...
  424. @overload
  425. def tensordot(
  426. a: _ArrayLikeInt_co,
  427. b: _ArrayLikeInt_co,
  428. /,
  429. *,
  430. axes: int | tuple[_ShapeLike, _ShapeLike] = 2,
  431. ) -> NDArray[np.int_ | Any]: ...
  432. @overload
  433. def tensordot(
  434. a: _ArrayLikeFloat_co,
  435. b: _ArrayLikeFloat_co,
  436. /,
  437. *,
  438. axes: int | tuple[_ShapeLike, _ShapeLike] = 2,
  439. ) -> NDArray[np.float64 | Any]: ...
  440. @overload
  441. def tensordot(
  442. a: _ArrayLikeComplex_co,
  443. b: _ArrayLikeComplex_co,
  444. /,
  445. *,
  446. axes: int | tuple[_ShapeLike, _ShapeLike] = 2,
  447. ) -> NDArray[np.complex128 | Any]: ...
  448. # TODO: Returns a scalar or array
  449. def multi_dot(
  450. arrays: Iterable[_ArrayLikeComplex_co | _ArrayLikeObject_co | _ArrayLikeTD64_co],
  451. *,
  452. out: NDArray[Any] | None = None,
  453. ) -> Any: ...
  454. def diagonal(
  455. x: ArrayLike, # >= 2D array
  456. /,
  457. *,
  458. offset: SupportsIndex = 0,
  459. ) -> NDArray[Any]: ...
  460. def trace(
  461. x: ArrayLike, # >= 2D array
  462. /,
  463. *,
  464. offset: SupportsIndex = 0,
  465. dtype: DTypeLike | None = None,
  466. ) -> Any: ...
  467. @overload
  468. def cross(
  469. x1: _ArrayLikeUInt_co,
  470. x2: _ArrayLikeUInt_co,
  471. /,
  472. *,
  473. axis: int = -1,
  474. ) -> NDArray[unsignedinteger]: ...
  475. @overload
  476. def cross(
  477. x1: _ArrayLikeInt_co,
  478. x2: _ArrayLikeInt_co,
  479. /,
  480. *,
  481. axis: int = -1,
  482. ) -> NDArray[signedinteger]: ...
  483. @overload
  484. def cross(
  485. x1: _ArrayLikeFloat_co,
  486. x2: _ArrayLikeFloat_co,
  487. /,
  488. *,
  489. axis: int = -1,
  490. ) -> NDArray[floating]: ...
  491. @overload
  492. def cross(
  493. x1: _ArrayLikeComplex_co,
  494. x2: _ArrayLikeComplex_co,
  495. /,
  496. *,
  497. axis: int = -1,
  498. ) -> NDArray[complexfloating]: ...
  499. @overload
  500. def matmul(x1: _ArrayLike[_NumberT], x2: _ArrayLike[_NumberT], /) -> NDArray[_NumberT]: ...
  501. @overload
  502. def matmul(x1: _ArrayLikeUInt_co, x2: _ArrayLikeUInt_co, /) -> NDArray[unsignedinteger]: ...
  503. @overload
  504. def matmul(x1: _ArrayLikeInt_co, x2: _ArrayLikeInt_co, /) -> NDArray[signedinteger]: ...
  505. @overload
  506. def matmul(x1: _ArrayLikeFloat_co, x2: _ArrayLikeFloat_co, /) -> NDArray[floating]: ...
  507. @overload
  508. def matmul(x1: _ArrayLikeComplex_co, x2: _ArrayLikeComplex_co, /) -> NDArray[complexfloating]: ...