test_ufunc_signatures.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. """Test that all ufuncs have float32-preserving signatures.
  2. This was once guaranteed through the code generation script for
  3. generating ufuncs, `scipy/special/_generate_pyx.py`. Starting with
  4. gh-20260, SciPy developers have begun moving to generate ufuncs
  5. through direct use of the NumPy C API (through C++). Existence of
  6. float32 preserving signatures must now be tested since it is no
  7. longer guaranteed.
  8. """
  9. import numpy as np
  10. import pytest
  11. import scipy.special._ufuncs
  12. import scipy.special._gufuncs
  13. # Single precision is not implemented for these ufuncs;
  14. # floating point inputs must be float64.
  15. exceptions = ['_gen_harmonic', '_normalized_gen_harmonic']
  16. _ufuncs = []
  17. for funcname in dir(scipy.special._ufuncs):
  18. if funcname not in exceptions:
  19. _ufuncs.append(getattr(scipy.special._ufuncs, funcname))
  20. for funcname in dir(scipy.special._gufuncs):
  21. _ufuncs.append(getattr(scipy.special._gufuncs, funcname))
  22. # Not all module members are actually ufuncs
  23. _ufuncs = [func for func in _ufuncs if isinstance(func, np.ufunc)]
  24. @pytest.mark.parametrize("ufunc", _ufuncs)
  25. def test_ufunc_signatures(ufunc):
  26. # From _generate_pyx.py
  27. # "Don't add float32 versions of ufuncs with integer arguments, as this
  28. # can lead to incorrect dtype selection if the integer arguments are
  29. # arrays, but float arguments are scalars.
  30. # This may be a NumPy bug, but we need to work around it.
  31. # cf. gh-4895, https://github.com/numpy/numpy/issues/5895"
  32. types = set(sig for sig in ufunc.types
  33. if not ("l" in sig or "i" in sig or "q" in sig or "p" in sig))
  34. # Generate the full expanded set of signatures which should exist. There
  35. # should be matching float and double versions of any existing signature.
  36. expanded_types = set()
  37. for sig in types:
  38. expanded_types.update(
  39. [sig.replace("d", "f").replace("D", "F"),
  40. sig.replace("f", "d").replace("F", "D")]
  41. )
  42. assert types == expanded_types