einsumfunc.pyi 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. from collections.abc import Sequence
  2. from typing import TypeAlias, TypeVar, Any, overload, Literal
  3. import numpy as np
  4. from numpy import number, _OrderKACF
  5. from numpy._typing import (
  6. NDArray,
  7. _ArrayLikeBool_co,
  8. _ArrayLikeUInt_co,
  9. _ArrayLikeInt_co,
  10. _ArrayLikeFloat_co,
  11. _ArrayLikeComplex_co,
  12. _ArrayLikeObject_co,
  13. _DTypeLikeBool,
  14. _DTypeLikeUInt,
  15. _DTypeLikeInt,
  16. _DTypeLikeFloat,
  17. _DTypeLikeComplex,
  18. _DTypeLikeComplex_co,
  19. _DTypeLikeObject,
  20. )
  21. __all__ = ["einsum", "einsum_path"]
  22. _ArrayType = TypeVar(
  23. "_ArrayType",
  24. bound=NDArray[np.bool | number[Any]],
  25. )
  26. _OptimizeKind: TypeAlias = bool | Literal["greedy", "optimal"] | Sequence[Any] | None
  27. _CastingSafe: TypeAlias = Literal["no", "equiv", "safe", "same_kind"]
  28. _CastingUnsafe: TypeAlias = Literal["unsafe"]
  29. # TODO: Properly handle the `casting`-based combinatorics
  30. # TODO: We need to evaluate the content `__subscripts` in order
  31. # to identify whether or an array or scalar is returned. At a cursory
  32. # glance this seems like something that can quite easily be done with
  33. # a mypy plugin.
  34. # Something like `is_scalar = bool(__subscripts.partition("->")[-1])`
  35. @overload
  36. def einsum(
  37. subscripts: str | _ArrayLikeInt_co,
  38. /,
  39. *operands: _ArrayLikeBool_co,
  40. out: None = ...,
  41. dtype: None | _DTypeLikeBool = ...,
  42. order: _OrderKACF = ...,
  43. casting: _CastingSafe = ...,
  44. optimize: _OptimizeKind = ...,
  45. ) -> Any: ...
  46. @overload
  47. def einsum(
  48. subscripts: str | _ArrayLikeInt_co,
  49. /,
  50. *operands: _ArrayLikeUInt_co,
  51. out: None = ...,
  52. dtype: None | _DTypeLikeUInt = ...,
  53. order: _OrderKACF = ...,
  54. casting: _CastingSafe = ...,
  55. optimize: _OptimizeKind = ...,
  56. ) -> Any: ...
  57. @overload
  58. def einsum(
  59. subscripts: str | _ArrayLikeInt_co,
  60. /,
  61. *operands: _ArrayLikeInt_co,
  62. out: None = ...,
  63. dtype: None | _DTypeLikeInt = ...,
  64. order: _OrderKACF = ...,
  65. casting: _CastingSafe = ...,
  66. optimize: _OptimizeKind = ...,
  67. ) -> Any: ...
  68. @overload
  69. def einsum(
  70. subscripts: str | _ArrayLikeInt_co,
  71. /,
  72. *operands: _ArrayLikeFloat_co,
  73. out: None = ...,
  74. dtype: None | _DTypeLikeFloat = ...,
  75. order: _OrderKACF = ...,
  76. casting: _CastingSafe = ...,
  77. optimize: _OptimizeKind = ...,
  78. ) -> Any: ...
  79. @overload
  80. def einsum(
  81. subscripts: str | _ArrayLikeInt_co,
  82. /,
  83. *operands: _ArrayLikeComplex_co,
  84. out: None = ...,
  85. dtype: None | _DTypeLikeComplex = ...,
  86. order: _OrderKACF = ...,
  87. casting: _CastingSafe = ...,
  88. optimize: _OptimizeKind = ...,
  89. ) -> Any: ...
  90. @overload
  91. def einsum(
  92. subscripts: str | _ArrayLikeInt_co,
  93. /,
  94. *operands: Any,
  95. casting: _CastingUnsafe,
  96. dtype: None | _DTypeLikeComplex_co = ...,
  97. out: None = ...,
  98. order: _OrderKACF = ...,
  99. optimize: _OptimizeKind = ...,
  100. ) -> Any: ...
  101. @overload
  102. def einsum(
  103. subscripts: str | _ArrayLikeInt_co,
  104. /,
  105. *operands: _ArrayLikeComplex_co,
  106. out: _ArrayType,
  107. dtype: None | _DTypeLikeComplex_co = ...,
  108. order: _OrderKACF = ...,
  109. casting: _CastingSafe = ...,
  110. optimize: _OptimizeKind = ...,
  111. ) -> _ArrayType: ...
  112. @overload
  113. def einsum(
  114. subscripts: str | _ArrayLikeInt_co,
  115. /,
  116. *operands: Any,
  117. out: _ArrayType,
  118. casting: _CastingUnsafe,
  119. dtype: None | _DTypeLikeComplex_co = ...,
  120. order: _OrderKACF = ...,
  121. optimize: _OptimizeKind = ...,
  122. ) -> _ArrayType: ...
  123. @overload
  124. def einsum(
  125. subscripts: str | _ArrayLikeInt_co,
  126. /,
  127. *operands: _ArrayLikeObject_co,
  128. out: None = ...,
  129. dtype: None | _DTypeLikeObject = ...,
  130. order: _OrderKACF = ...,
  131. casting: _CastingSafe = ...,
  132. optimize: _OptimizeKind = ...,
  133. ) -> Any: ...
  134. @overload
  135. def einsum(
  136. subscripts: str | _ArrayLikeInt_co,
  137. /,
  138. *operands: Any,
  139. casting: _CastingUnsafe,
  140. dtype: None | _DTypeLikeObject = ...,
  141. out: None = ...,
  142. order: _OrderKACF = ...,
  143. optimize: _OptimizeKind = ...,
  144. ) -> Any: ...
  145. @overload
  146. def einsum(
  147. subscripts: str | _ArrayLikeInt_co,
  148. /,
  149. *operands: _ArrayLikeObject_co,
  150. out: _ArrayType,
  151. dtype: None | _DTypeLikeObject = ...,
  152. order: _OrderKACF = ...,
  153. casting: _CastingSafe = ...,
  154. optimize: _OptimizeKind = ...,
  155. ) -> _ArrayType: ...
  156. @overload
  157. def einsum(
  158. subscripts: str | _ArrayLikeInt_co,
  159. /,
  160. *operands: Any,
  161. out: _ArrayType,
  162. casting: _CastingUnsafe,
  163. dtype: None | _DTypeLikeObject = ...,
  164. order: _OrderKACF = ...,
  165. optimize: _OptimizeKind = ...,
  166. ) -> _ArrayType: ...
  167. # NOTE: `einsum_call` is a hidden kwarg unavailable for public use.
  168. # It is therefore excluded from the signatures below.
  169. # NOTE: In practice the list consists of a `str` (first element)
  170. # and a variable number of integer tuples.
  171. def einsum_path(
  172. subscripts: str | _ArrayLikeInt_co,
  173. /,
  174. *operands: _ArrayLikeComplex_co | _DTypeLikeObject,
  175. optimize: _OptimizeKind = "greedy",
  176. einsum_call: Literal[False] = False,
  177. ) -> tuple[list[Any], str]: ...