utils.pyi 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. import ast
  2. import sys
  3. import types
  4. import unittest
  5. import warnings
  6. from collections.abc import Callable, Iterable, Sequence
  7. from contextlib import _GeneratorContextManager
  8. from pathlib import Path
  9. from re import Pattern
  10. from typing import (
  11. Any,
  12. AnyStr,
  13. ClassVar,
  14. Final,
  15. Generic,
  16. NoReturn,
  17. SupportsIndex,
  18. TypeAlias,
  19. overload,
  20. type_check_only,
  21. )
  22. from typing import Literal as L
  23. from unittest.case import SkipTest
  24. from _typeshed import ConvertibleToFloat, GenericPath, StrOrBytesPath, StrPath
  25. from typing_extensions import ParamSpec, Self, TypeVar, TypeVarTuple, Unpack
  26. import numpy as np
  27. from numpy._typing import (
  28. ArrayLike,
  29. DTypeLike,
  30. NDArray,
  31. _ArrayLikeDT64_co,
  32. _ArrayLikeNumber_co,
  33. _ArrayLikeObject_co,
  34. _ArrayLikeTD64_co,
  35. )
  36. __all__ = [ # noqa: RUF022
  37. "IS_EDITABLE",
  38. "IS_MUSL",
  39. "IS_PYPY",
  40. "IS_PYSTON",
  41. "IS_WASM",
  42. "HAS_LAPACK64",
  43. "HAS_REFCOUNT",
  44. "NOGIL_BUILD",
  45. "assert_",
  46. "assert_array_almost_equal_nulp",
  47. "assert_raises_regex",
  48. "assert_array_max_ulp",
  49. "assert_warns",
  50. "assert_no_warnings",
  51. "assert_allclose",
  52. "assert_equal",
  53. "assert_almost_equal",
  54. "assert_approx_equal",
  55. "assert_array_equal",
  56. "assert_array_less",
  57. "assert_string_equal",
  58. "assert_array_almost_equal",
  59. "assert_raises",
  60. "build_err_msg",
  61. "decorate_methods",
  62. "jiffies",
  63. "memusage",
  64. "print_assert_equal",
  65. "rundocs",
  66. "runstring",
  67. "verbose",
  68. "measure",
  69. "IgnoreException",
  70. "clear_and_catch_warnings",
  71. "SkipTest",
  72. "KnownFailureException",
  73. "temppath",
  74. "tempdir",
  75. "suppress_warnings",
  76. "assert_array_compare",
  77. "assert_no_gc_cycles",
  78. "break_cycles",
  79. "check_support_sve",
  80. "run_threaded",
  81. ]
  82. ###
  83. _T = TypeVar("_T")
  84. _Ts = TypeVarTuple("_Ts")
  85. _Tss = ParamSpec("_Tss")
  86. _ET = TypeVar("_ET", bound=BaseException, default=BaseException)
  87. _FT = TypeVar("_FT", bound=Callable[..., Any])
  88. _W_co = TypeVar("_W_co", bound=_WarnLog | None, default=_WarnLog | None, covariant=True)
  89. _T_or_bool = TypeVar("_T_or_bool", default=bool)
  90. _StrLike: TypeAlias = str | bytes
  91. _RegexLike: TypeAlias = _StrLike | Pattern[Any]
  92. _NumericArrayLike: TypeAlias = _ArrayLikeNumber_co | _ArrayLikeObject_co
  93. _ExceptionSpec: TypeAlias = type[_ET] | tuple[type[_ET], ...]
  94. _WarningSpec: TypeAlias = type[Warning]
  95. _WarnLog: TypeAlias = list[warnings.WarningMessage]
  96. _ToModules: TypeAlias = Iterable[types.ModuleType]
  97. # Must return a bool or an ndarray/generic type that is supported by `np.logical_and.reduce`
  98. _ComparisonFunc: TypeAlias = Callable[
  99. [NDArray[Any], NDArray[Any]],
  100. bool | np.bool | np.number | NDArray[np.bool | np.number | np.object_],
  101. ]
  102. # Type-check only `clear_and_catch_warnings` subclasses for both values of the
  103. # `record` parameter. Copied from the stdlib `warnings` stubs.
  104. @type_check_only
  105. class _clear_and_catch_warnings_with_records(clear_and_catch_warnings):
  106. def __enter__(self) -> list[warnings.WarningMessage]: ...
  107. @type_check_only
  108. class _clear_and_catch_warnings_without_records(clear_and_catch_warnings):
  109. def __enter__(self) -> None: ...
  110. ###
  111. verbose: int = 0
  112. NUMPY_ROOT: Final[Path] = ...
  113. IS_INSTALLED: Final[bool] = ...
  114. IS_EDITABLE: Final[bool] = ...
  115. IS_MUSL: Final[bool] = ...
  116. IS_PYPY: Final[bool] = ...
  117. IS_PYSTON: Final[bool] = ...
  118. IS_WASM: Final[bool] = ...
  119. HAS_REFCOUNT: Final[bool] = ...
  120. HAS_LAPACK64: Final[bool] = ...
  121. NOGIL_BUILD: Final[bool] = ...
  122. class KnownFailureException(Exception): ...
  123. class IgnoreException(Exception): ...
  124. # NOTE: `warnings.catch_warnings` is incorrectly defined as invariant in typeshed
  125. class clear_and_catch_warnings(warnings.catch_warnings[_W_co], Generic[_W_co]): # type: ignore[type-var] # pyright: ignore[reportInvalidTypeArguments]
  126. class_modules: ClassVar[tuple[types.ModuleType, ...]] = ()
  127. modules: Final[set[types.ModuleType]]
  128. @overload # record: True
  129. def __init__(self: clear_and_catch_warnings[_WarnLog], /, record: L[True], modules: _ToModules = ()) -> None: ...
  130. @overload # record: False (default)
  131. def __init__(self: clear_and_catch_warnings[None], /, record: L[False] = False, modules: _ToModules = ()) -> None: ...
  132. @overload # record; bool
  133. def __init__(self, /, record: bool, modules: _ToModules = ()) -> None: ...
  134. class suppress_warnings:
  135. log: Final[_WarnLog]
  136. def __init__(self, /, forwarding_rule: L["always", "module", "once", "location"] = "always") -> None: ...
  137. def __enter__(self) -> Self: ...
  138. def __exit__(self, cls: type[BaseException] | None, exc: BaseException | None, tb: types.TracebackType | None, /) -> None: ...
  139. def __call__(self, /, func: _FT) -> _FT: ...
  140. #
  141. def filter(self, /, category: type[Warning] = ..., message: str = "", module: types.ModuleType | None = None) -> None: ...
  142. def record(self, /, category: type[Warning] = ..., message: str = "", module: types.ModuleType | None = None) -> _WarnLog: ...
  143. # Contrary to runtime we can't do `os.name` checks while type checking,
  144. # only `sys.platform` checks
  145. if sys.platform == "win32" or sys.platform == "cygwin":
  146. def memusage(processName: str = ..., instance: int = ...) -> int: ...
  147. elif sys.platform == "linux":
  148. def memusage(_proc_pid_stat: StrOrBytesPath = ...) -> int | None: ...
  149. else:
  150. def memusage() -> NoReturn: ...
  151. if sys.platform == "linux":
  152. def jiffies(_proc_pid_stat: StrOrBytesPath = ..., _load_time: list[float] = []) -> int: ...
  153. else:
  154. def jiffies(_load_time: list[float] = []) -> int: ...
  155. #
  156. def build_err_msg(
  157. arrays: Iterable[object],
  158. err_msg: object,
  159. header: str = ...,
  160. verbose: bool = ...,
  161. names: Sequence[str] = ...,
  162. precision: SupportsIndex | None = ...,
  163. ) -> str: ...
  164. #
  165. def print_assert_equal(test_string: str, actual: object, desired: object) -> None: ...
  166. #
  167. def assert_(val: object, msg: str | Callable[[], str] = "") -> None: ...
  168. #
  169. def assert_equal(
  170. actual: object,
  171. desired: object,
  172. err_msg: object = "",
  173. verbose: bool = True,
  174. *,
  175. strict: bool = False,
  176. ) -> None: ...
  177. def assert_almost_equal(
  178. actual: _NumericArrayLike,
  179. desired: _NumericArrayLike,
  180. decimal: int = 7,
  181. err_msg: object = "",
  182. verbose: bool = True,
  183. ) -> None: ...
  184. #
  185. def assert_approx_equal(
  186. actual: ConvertibleToFloat,
  187. desired: ConvertibleToFloat,
  188. significant: int = 7,
  189. err_msg: object = "",
  190. verbose: bool = True,
  191. ) -> None: ...
  192. #
  193. def assert_array_compare(
  194. comparison: _ComparisonFunc,
  195. x: ArrayLike,
  196. y: ArrayLike,
  197. err_msg: object = "",
  198. verbose: bool = True,
  199. header: str = "",
  200. precision: SupportsIndex = 6,
  201. equal_nan: bool = True,
  202. equal_inf: bool = True,
  203. *,
  204. strict: bool = False,
  205. names: tuple[str, str] = ("ACTUAL", "DESIRED"),
  206. ) -> None: ...
  207. #
  208. def assert_array_equal(
  209. actual: object,
  210. desired: object,
  211. err_msg: object = "",
  212. verbose: bool = True,
  213. *,
  214. strict: bool = False,
  215. ) -> None: ...
  216. #
  217. def assert_array_almost_equal(
  218. actual: _NumericArrayLike,
  219. desired: _NumericArrayLike,
  220. decimal: float = 6,
  221. err_msg: object = "",
  222. verbose: bool = True,
  223. ) -> None: ...
  224. @overload
  225. def assert_array_less(
  226. x: _ArrayLikeDT64_co,
  227. y: _ArrayLikeDT64_co,
  228. err_msg: object = "",
  229. verbose: bool = True,
  230. *,
  231. strict: bool = False,
  232. ) -> None: ...
  233. @overload
  234. def assert_array_less(
  235. x: _ArrayLikeTD64_co,
  236. y: _ArrayLikeTD64_co,
  237. err_msg: object = "",
  238. verbose: bool = True,
  239. *,
  240. strict: bool = False,
  241. ) -> None: ...
  242. @overload
  243. def assert_array_less(
  244. x: _NumericArrayLike,
  245. y: _NumericArrayLike,
  246. err_msg: object = "",
  247. verbose: bool = True,
  248. *,
  249. strict: bool = False,
  250. ) -> None: ...
  251. #
  252. def assert_string_equal(actual: str, desired: str) -> None: ...
  253. #
  254. @overload
  255. def assert_raises(
  256. exception_class: _ExceptionSpec[_ET],
  257. /,
  258. *,
  259. msg: str | None = None,
  260. ) -> unittest.case._AssertRaisesContext[_ET]: ...
  261. @overload
  262. def assert_raises(
  263. exception_class: _ExceptionSpec,
  264. callable: Callable[_Tss, Any],
  265. /,
  266. *args: _Tss.args,
  267. **kwargs: _Tss.kwargs,
  268. ) -> None: ...
  269. #
  270. @overload
  271. def assert_raises_regex(
  272. exception_class: _ExceptionSpec[_ET],
  273. expected_regexp: _RegexLike,
  274. *,
  275. msg: str | None = None,
  276. ) -> unittest.case._AssertRaisesContext[_ET]: ...
  277. @overload
  278. def assert_raises_regex(
  279. exception_class: _ExceptionSpec,
  280. expected_regexp: _RegexLike,
  281. callable: Callable[_Tss, Any],
  282. *args: _Tss.args,
  283. **kwargs: _Tss.kwargs,
  284. ) -> None: ...
  285. #
  286. @overload
  287. def assert_allclose(
  288. actual: _ArrayLikeTD64_co,
  289. desired: _ArrayLikeTD64_co,
  290. rtol: float = 1e-7,
  291. atol: float = 0,
  292. equal_nan: bool = True,
  293. err_msg: object = "",
  294. verbose: bool = True,
  295. *,
  296. strict: bool = False,
  297. ) -> None: ...
  298. @overload
  299. def assert_allclose(
  300. actual: _NumericArrayLike,
  301. desired: _NumericArrayLike,
  302. rtol: float = 1e-7,
  303. atol: float = 0,
  304. equal_nan: bool = True,
  305. err_msg: object = "",
  306. verbose: bool = True,
  307. *,
  308. strict: bool = False,
  309. ) -> None: ...
  310. #
  311. def assert_array_almost_equal_nulp(
  312. x: _ArrayLikeNumber_co,
  313. y: _ArrayLikeNumber_co,
  314. nulp: float = 1,
  315. ) -> None: ...
  316. #
  317. def assert_array_max_ulp(
  318. a: _ArrayLikeNumber_co,
  319. b: _ArrayLikeNumber_co,
  320. maxulp: float = 1,
  321. dtype: DTypeLike | None = None,
  322. ) -> NDArray[Any]: ...
  323. #
  324. @overload
  325. def assert_warns(warning_class: _WarningSpec) -> _GeneratorContextManager[None]: ...
  326. @overload
  327. def assert_warns(warning_class: _WarningSpec, func: Callable[_Tss, _T], *args: _Tss.args, **kwargs: _Tss.kwargs) -> _T: ...
  328. #
  329. @overload
  330. def assert_no_warnings() -> _GeneratorContextManager[None]: ...
  331. @overload
  332. def assert_no_warnings(func: Callable[_Tss, _T], /, *args: _Tss.args, **kwargs: _Tss.kwargs) -> _T: ...
  333. #
  334. @overload
  335. def assert_no_gc_cycles() -> _GeneratorContextManager[None]: ...
  336. @overload
  337. def assert_no_gc_cycles(func: Callable[_Tss, Any], /, *args: _Tss.args, **kwargs: _Tss.kwargs) -> None: ...
  338. ###
  339. #
  340. @overload
  341. def tempdir(
  342. suffix: None = None,
  343. prefix: None = None,
  344. dir: None = None,
  345. ) -> _GeneratorContextManager[str]: ...
  346. @overload
  347. def tempdir(
  348. suffix: AnyStr | None = None,
  349. prefix: AnyStr | None = None,
  350. *,
  351. dir: GenericPath[AnyStr],
  352. ) -> _GeneratorContextManager[AnyStr]: ...
  353. @overload
  354. def tempdir(
  355. suffix: AnyStr | None = None,
  356. *,
  357. prefix: AnyStr,
  358. dir: GenericPath[AnyStr] | None = None,
  359. ) -> _GeneratorContextManager[AnyStr]: ...
  360. @overload
  361. def tempdir(
  362. suffix: AnyStr,
  363. prefix: AnyStr | None = None,
  364. dir: GenericPath[AnyStr] | None = None,
  365. ) -> _GeneratorContextManager[AnyStr]: ...
  366. #
  367. @overload
  368. def temppath(
  369. suffix: None = None,
  370. prefix: None = None,
  371. dir: None = None,
  372. text: bool = False,
  373. ) -> _GeneratorContextManager[str]: ...
  374. @overload
  375. def temppath(
  376. suffix: AnyStr | None,
  377. prefix: AnyStr | None,
  378. dir: GenericPath[AnyStr],
  379. text: bool = False,
  380. ) -> _GeneratorContextManager[AnyStr]: ...
  381. @overload
  382. def temppath(
  383. suffix: AnyStr | None = None,
  384. prefix: AnyStr | None = None,
  385. *,
  386. dir: GenericPath[AnyStr],
  387. text: bool = False,
  388. ) -> _GeneratorContextManager[AnyStr]: ...
  389. @overload
  390. def temppath(
  391. suffix: AnyStr | None,
  392. prefix: AnyStr,
  393. dir: GenericPath[AnyStr] | None = None,
  394. text: bool = False,
  395. ) -> _GeneratorContextManager[AnyStr]: ...
  396. @overload
  397. def temppath(
  398. suffix: AnyStr | None = None,
  399. *,
  400. prefix: AnyStr,
  401. dir: GenericPath[AnyStr] | None = None,
  402. text: bool = False,
  403. ) -> _GeneratorContextManager[AnyStr]: ...
  404. @overload
  405. def temppath(
  406. suffix: AnyStr,
  407. prefix: AnyStr | None = None,
  408. dir: GenericPath[AnyStr] | None = None,
  409. text: bool = False,
  410. ) -> _GeneratorContextManager[AnyStr]: ...
  411. #
  412. def check_support_sve(__cache: list[_T_or_bool] = []) -> _T_or_bool: ... # noqa: PYI063
  413. #
  414. def decorate_methods(
  415. cls: type,
  416. decorator: Callable[[Callable[..., Any]], Any],
  417. testmatch: _RegexLike | None = None,
  418. ) -> None: ...
  419. #
  420. @overload
  421. def run_threaded(
  422. func: Callable[[], None],
  423. max_workers: int = 8,
  424. pass_count: bool = False,
  425. pass_barrier: bool = False,
  426. outer_iterations: int = 1,
  427. prepare_args: None = None,
  428. ) -> None: ...
  429. @overload
  430. def run_threaded(
  431. func: Callable[[Unpack[_Ts]], None],
  432. max_workers: int,
  433. pass_count: bool,
  434. pass_barrier: bool,
  435. outer_iterations: int,
  436. prepare_args: tuple[Unpack[_Ts]],
  437. ) -> None: ...
  438. @overload
  439. def run_threaded(
  440. func: Callable[[Unpack[_Ts]], None],
  441. max_workers: int = 8,
  442. pass_count: bool = False,
  443. pass_barrier: bool = False,
  444. outer_iterations: int = 1,
  445. *,
  446. prepare_args: tuple[Unpack[_Ts]],
  447. ) -> None: ...
  448. #
  449. def runstring(astr: _StrLike | types.CodeType, dict: dict[str, Any] | None) -> Any: ... # noqa: ANN401
  450. def rundocs(filename: StrPath | None = None, raise_on_error: bool = True) -> None: ...
  451. def measure(code_str: _StrLike | ast.AST, times: int = 1, label: str | None = None) -> float: ...
  452. def break_cycles() -> None: ...