pytest.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. """py.test hacks to support XFAIL/XPASS"""
  2. import platform
  3. import sys
  4. import re
  5. import functools
  6. import os
  7. import contextlib
  8. import warnings
  9. import inspect
  10. import pathlib
  11. from typing import Any, Callable
  12. from sympy.utilities.exceptions import SymPyDeprecationWarning
  13. # Imported here for backwards compatibility. Note: do not import this from
  14. # here in library code (importing sympy.pytest in library code will break the
  15. # pytest integration).
  16. from sympy.utilities.exceptions import ignore_warnings # noqa:F401
  17. ON_CI = os.getenv('CI', None) == "true"
  18. try:
  19. import pytest
  20. USE_PYTEST = getattr(sys, '_running_pytest', False)
  21. except ImportError:
  22. USE_PYTEST = False
  23. IS_WASM: bool = sys.platform == 'emscripten' or platform.machine() in ["wasm32", "wasm64"]
  24. raises: Callable[[Any, Any], Any]
  25. XFAIL: Callable[[Any], Any]
  26. skip: Callable[[Any], Any]
  27. SKIP: Callable[[Any], Any]
  28. slow: Callable[[Any], Any]
  29. tooslow: Callable[[Any], Any]
  30. nocache_fail: Callable[[Any], Any]
  31. if USE_PYTEST:
  32. raises = pytest.raises
  33. skip = pytest.skip
  34. XFAIL = pytest.mark.xfail
  35. SKIP = pytest.mark.skip
  36. slow = pytest.mark.slow
  37. tooslow = pytest.mark.tooslow
  38. nocache_fail = pytest.mark.nocache_fail
  39. from _pytest.outcomes import Failed
  40. else:
  41. # Not using pytest so define the things that would have been imported from
  42. # there.
  43. # _pytest._code.code.ExceptionInfo
  44. class ExceptionInfo:
  45. def __init__(self, value):
  46. self.value = value
  47. def __repr__(self):
  48. return "<ExceptionInfo {!r}>".format(self.value)
  49. def raises(expectedException, code=None):
  50. """
  51. Tests that ``code`` raises the exception ``expectedException``.
  52. ``code`` may be a callable, such as a lambda expression or function
  53. name.
  54. If ``code`` is not given or None, ``raises`` will return a context
  55. manager for use in ``with`` statements; the code to execute then
  56. comes from the scope of the ``with``.
  57. ``raises()`` does nothing if the callable raises the expected exception,
  58. otherwise it raises an AssertionError.
  59. Examples
  60. ========
  61. >>> from sympy.testing.pytest import raises
  62. >>> raises(ZeroDivisionError, lambda: 1/0)
  63. <ExceptionInfo ZeroDivisionError(...)>
  64. >>> raises(ZeroDivisionError, lambda: 1/2)
  65. Traceback (most recent call last):
  66. ...
  67. Failed: DID NOT RAISE
  68. >>> with raises(ZeroDivisionError):
  69. ... n = 1/0
  70. >>> with raises(ZeroDivisionError):
  71. ... n = 1/2
  72. Traceback (most recent call last):
  73. ...
  74. Failed: DID NOT RAISE
  75. Note that you cannot test multiple statements via
  76. ``with raises``:
  77. >>> with raises(ZeroDivisionError):
  78. ... n = 1/0 # will execute and raise, aborting the ``with``
  79. ... n = 9999/0 # never executed
  80. This is just what ``with`` is supposed to do: abort the
  81. contained statement sequence at the first exception and let
  82. the context manager deal with the exception.
  83. To test multiple statements, you'll need a separate ``with``
  84. for each:
  85. >>> with raises(ZeroDivisionError):
  86. ... n = 1/0 # will execute and raise
  87. >>> with raises(ZeroDivisionError):
  88. ... n = 9999/0 # will also execute and raise
  89. """
  90. if code is None:
  91. return RaisesContext(expectedException)
  92. elif callable(code):
  93. try:
  94. code()
  95. except expectedException as e:
  96. return ExceptionInfo(e)
  97. raise Failed("DID NOT RAISE")
  98. elif isinstance(code, str):
  99. raise TypeError(
  100. '\'raises(xxx, "code")\' has been phased out; '
  101. 'change \'raises(xxx, "expression")\' '
  102. 'to \'raises(xxx, lambda: expression)\', '
  103. '\'raises(xxx, "statement")\' '
  104. 'to \'with raises(xxx): statement\'')
  105. else:
  106. raise TypeError(
  107. 'raises() expects a callable for the 2nd argument.')
  108. class RaisesContext:
  109. def __init__(self, expectedException):
  110. self.expectedException = expectedException
  111. def __enter__(self):
  112. return None
  113. def __exit__(self, exc_type, exc_value, traceback):
  114. if exc_type is None:
  115. raise Failed("DID NOT RAISE")
  116. return issubclass(exc_type, self.expectedException)
  117. class XFail(Exception):
  118. pass
  119. class XPass(Exception):
  120. pass
  121. class Skipped(Exception):
  122. pass
  123. class Failed(Exception): # type: ignore
  124. pass
  125. def XFAIL(func):
  126. def wrapper():
  127. try:
  128. func()
  129. except Exception as e:
  130. message = str(e)
  131. if message != "Timeout":
  132. raise XFail(func.__name__)
  133. else:
  134. raise Skipped("Timeout")
  135. raise XPass(func.__name__)
  136. wrapper = functools.update_wrapper(wrapper, func)
  137. return wrapper
  138. def skip(str):
  139. raise Skipped(str)
  140. def SKIP(reason):
  141. """Similar to ``skip()``, but this is a decorator. """
  142. def wrapper(func):
  143. def func_wrapper():
  144. raise Skipped(reason)
  145. func_wrapper = functools.update_wrapper(func_wrapper, func)
  146. return func_wrapper
  147. return wrapper
  148. def slow(func):
  149. func._slow = True
  150. def func_wrapper():
  151. func()
  152. func_wrapper = functools.update_wrapper(func_wrapper, func)
  153. func_wrapper.__wrapped__ = func
  154. return func_wrapper
  155. def tooslow(func):
  156. func._slow = True
  157. func._tooslow = True
  158. def func_wrapper():
  159. skip("Too slow")
  160. func_wrapper = functools.update_wrapper(func_wrapper, func)
  161. func_wrapper.__wrapped__ = func
  162. return func_wrapper
  163. def nocache_fail(func):
  164. "Dummy decorator for marking tests that fail when cache is disabled"
  165. return func
  166. @contextlib.contextmanager
  167. def warns(warningcls, *, match='', test_stacklevel=True):
  168. '''
  169. Like raises but tests that warnings are emitted.
  170. >>> from sympy.testing.pytest import warns
  171. >>> import warnings
  172. >>> with warns(UserWarning):
  173. ... warnings.warn('deprecated', UserWarning, stacklevel=2)
  174. >>> with warns(UserWarning):
  175. ... pass
  176. Traceback (most recent call last):
  177. ...
  178. Failed: DID NOT WARN. No warnings of type UserWarning\
  179. was emitted. The list of emitted warnings is: [].
  180. ``test_stacklevel`` makes it check that the ``stacklevel`` parameter to
  181. ``warn()`` is set so that the warning shows the user line of code (the
  182. code under the warns() context manager). Set this to False if this is
  183. ambiguous or if the context manager does not test the direct user code
  184. that emits the warning.
  185. If the warning is a ``SymPyDeprecationWarning``, this additionally tests
  186. that the ``active_deprecations_target`` is a real target in the
  187. ``active-deprecations.md`` file.
  188. '''
  189. # Absorbs all warnings in warnrec
  190. with warnings.catch_warnings(record=True) as warnrec:
  191. # Any warning other than the one we are looking for is an error
  192. warnings.simplefilter("error")
  193. warnings.filterwarnings("always", category=warningcls)
  194. # Now run the test
  195. yield warnrec
  196. # Raise if expected warning not found
  197. if not any(issubclass(w.category, warningcls) for w in warnrec):
  198. msg = ('Failed: DID NOT WARN.'
  199. ' No warnings of type %s was emitted.'
  200. ' The list of emitted warnings is: %s.'
  201. ) % (warningcls, [w.message for w in warnrec])
  202. raise Failed(msg)
  203. # We don't include the match in the filter above because it would then
  204. # fall to the error filter, so we instead manually check that it matches
  205. # here
  206. for w in warnrec:
  207. # Should always be true due to the filters above
  208. assert issubclass(w.category, warningcls)
  209. if not re.compile(match, re.IGNORECASE).match(str(w.message)):
  210. raise Failed(f"Failed: WRONG MESSAGE. A warning with of the correct category ({warningcls.__name__}) was issued, but it did not match the given match regex ({match!r})")
  211. if test_stacklevel:
  212. for f in inspect.stack():
  213. thisfile = f.filename
  214. file = os.path.split(thisfile)[1]
  215. if file.startswith('test_'):
  216. break
  217. elif file == 'doctest.py':
  218. # skip the stacklevel testing in the doctests of this
  219. # function
  220. return
  221. else:
  222. raise RuntimeError("Could not find the file for the given warning to test the stacklevel")
  223. for w in warnrec:
  224. if w.filename != thisfile:
  225. msg = f'''\
  226. Failed: Warning has the wrong stacklevel. The warning stacklevel needs to be
  227. set so that the line of code shown in the warning message is user code that
  228. calls the deprecated code (the current stacklevel is showing code from
  229. {w.filename} (line {w.lineno}), expected {thisfile})'''.replace('\n', ' ')
  230. raise Failed(msg)
  231. if warningcls == SymPyDeprecationWarning:
  232. this_file = pathlib.Path(__file__)
  233. active_deprecations_file = (this_file.parent.parent.parent / 'doc' /
  234. 'src' / 'explanation' /
  235. 'active-deprecations.md')
  236. if not active_deprecations_file.exists():
  237. # We can only test that the active_deprecations_target works if we are
  238. # in the git repo.
  239. return
  240. targets = []
  241. for w in warnrec:
  242. targets.append(w.message.active_deprecations_target)
  243. text = pathlib.Path(active_deprecations_file).read_text(encoding="utf-8")
  244. for target in targets:
  245. if f'({target})=' not in text:
  246. raise Failed(f"The active deprecations target {target!r} does not appear to be a valid target in the active-deprecations.md file ({active_deprecations_file}).")
  247. def _both_exp_pow(func):
  248. """
  249. Decorator used to run the test twice: the first time `e^x` is represented
  250. as ``Pow(E, x)``, the second time as ``exp(x)`` (exponential object is not
  251. a power).
  252. This is a temporary trick helping to manage the elimination of the class
  253. ``exp`` in favor of a replacement by ``Pow(E, ...)``.
  254. """
  255. from sympy.core.parameters import _exp_is_pow
  256. def func_wrap():
  257. with _exp_is_pow(True):
  258. func()
  259. with _exp_is_pow(False):
  260. func()
  261. wrapper = functools.update_wrapper(func_wrap, func)
  262. return wrapper
  263. @contextlib.contextmanager
  264. def warns_deprecated_sympy():
  265. '''
  266. Shorthand for ``warns(SymPyDeprecationWarning)``
  267. This is the recommended way to test that ``SymPyDeprecationWarning`` is
  268. emitted for deprecated features in SymPy. To test for other warnings use
  269. ``warns``. To suppress warnings without asserting that they are emitted
  270. use ``ignore_warnings``.
  271. .. note::
  272. ``warns_deprecated_sympy()`` is only intended for internal use in the
  273. SymPy test suite to test that a deprecation warning triggers properly.
  274. All other code in the SymPy codebase, including documentation examples,
  275. should not use deprecated behavior.
  276. If you are a user of SymPy and you want to disable
  277. SymPyDeprecationWarnings, use ``warnings`` filters (see
  278. :ref:`silencing-sympy-deprecation-warnings`).
  279. >>> from sympy.testing.pytest import warns_deprecated_sympy
  280. >>> from sympy.utilities.exceptions import sympy_deprecation_warning
  281. >>> with warns_deprecated_sympy():
  282. ... sympy_deprecation_warning("Don't use",
  283. ... deprecated_since_version="1.0",
  284. ... active_deprecations_target="active-deprecations")
  285. >>> with warns_deprecated_sympy():
  286. ... pass
  287. Traceback (most recent call last):
  288. ...
  289. Failed: DID NOT WARN. No warnings of type \
  290. SymPyDeprecationWarning was emitted. The list of emitted warnings is: [].
  291. .. note::
  292. Sometimes the stacklevel test will fail because the same warning is
  293. emitted multiple times. In this case, you can use
  294. :func:`sympy.utilities.exceptions.ignore_warnings` in the code to
  295. prevent the ``SymPyDeprecationWarning`` from being emitted again
  296. recursively. In rare cases it is impossible to have a consistent
  297. ``stacklevel`` for deprecation warnings because different ways of
  298. calling a function will produce different call stacks.. In those cases,
  299. use ``warns(SymPyDeprecationWarning)`` instead.
  300. See Also
  301. ========
  302. sympy.utilities.exceptions.SymPyDeprecationWarning
  303. sympy.utilities.exceptions.sympy_deprecation_warning
  304. sympy.utilities.decorator.deprecated
  305. '''
  306. with warns(SymPyDeprecationWarning):
  307. yield
  308. def skip_under_pyodide(message):
  309. """Decorator to skip a test if running under Pyodide/WASM."""
  310. def decorator(test_func):
  311. @functools.wraps(test_func)
  312. def test_wrapper():
  313. if IS_WASM:
  314. skip(message)
  315. return test_func()
  316. return test_wrapper
  317. return decorator