test_nan_inputs.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
"""Test how the ufuncs in ``scipy.special`` handle NaN inputs.
"""
  3. import warnings
  4. from collections.abc import Callable
  5. import numpy as np
  6. from numpy.testing import assert_array_equal, assert_
  7. import pytest
  8. import scipy.special as sc
  9. KNOWNFAILURES: dict[str, Callable] = {}
  10. POSTPROCESSING: dict[str, Callable] = {}
  11. def _get_ufuncs():
  12. ufuncs = []
  13. ufunc_names = []
  14. for name in sorted(sc.__dict__):
  15. obj = sc.__dict__[name]
  16. if not isinstance(obj, np.ufunc):
  17. continue
  18. msg = KNOWNFAILURES.get(obj)
  19. if msg is None:
  20. ufuncs.append(obj)
  21. ufunc_names.append(name)
  22. else:
  23. fail = pytest.mark.xfail(run=False, reason=msg)
  24. ufuncs.append(pytest.param(obj, marks=fail))
  25. ufunc_names.append(name)
  26. return ufuncs, ufunc_names
  27. UFUNCS, UFUNC_NAMES = _get_ufuncs()
  28. @pytest.mark.parametrize("func", UFUNCS, ids=UFUNC_NAMES)
  29. def test_nan_inputs(func):
  30. args = (np.nan,)*func.nin
  31. with warnings.catch_warnings():
  32. # Ignore warnings about unsafe casts from legacy wrappers
  33. warnings.filterwarnings(
  34. "ignore",
  35. "floating point number truncated to an integer",
  36. RuntimeWarning
  37. )
  38. try:
  39. with warnings.catch_warnings():
  40. warnings.simplefilter("ignore", DeprecationWarning)
  41. res = func(*args)
  42. except TypeError:
  43. # One of the arguments doesn't take real inputs
  44. return
  45. if func in POSTPROCESSING:
  46. res = POSTPROCESSING[func](*res)
  47. msg = f"got {res} instead of nan"
  48. assert_array_equal(np.isnan(res), True, err_msg=msg)
  49. def test_legacy_cast():
  50. with warnings.catch_warnings():
  51. warnings.filterwarnings(
  52. "ignore",
  53. "floating point number truncated to an integer",
  54. RuntimeWarning
  55. )
  56. res = sc.bdtrc(np.nan, 1, 0.5)
  57. assert_(np.isnan(res))