# test_multithreading.py (2.1 KB)
  1. from scipy import fft
  2. import numpy as np
  3. import pytest
  4. from numpy.testing import assert_allclose
  5. import multiprocessing
  6. import os
  7. @pytest.fixture(scope='module')
  8. def x():
  9. return np.random.randn(512, 128) # Must be large enough to qualify for mt
  10. @pytest.mark.parametrize("func", [
  11. fft.fft, fft.ifft, fft.fft2, fft.ifft2, fft.fftn, fft.ifftn,
  12. fft.rfft, fft.irfft, fft.rfft2, fft.irfft2, fft.rfftn, fft.irfftn,
  13. fft.hfft, fft.ihfft, fft.hfft2, fft.ihfft2, fft.hfftn, fft.ihfftn,
  14. fft.dct, fft.idct, fft.dctn, fft.idctn,
  15. fft.dst, fft.idst, fft.dstn, fft.idstn,
  16. ])
  17. @pytest.mark.parametrize("workers", [2, -1])
  18. def test_threaded_same(x, func, workers):
  19. expected = func(x, workers=1)
  20. actual = func(x, workers=workers)
  21. assert_allclose(actual, expected)
  22. def _mt_fft(x):
  23. return fft.fft(x, workers=2)
  24. @pytest.mark.slow
  25. def test_mixed_threads_processes(x):
  26. # Test that the fft threadpool is safe to use before & after fork
  27. expect = fft.fft(x, workers=2)
  28. with multiprocessing.Pool(2) as p:
  29. res = p.map(_mt_fft, [x for _ in range(4)])
  30. for r in res:
  31. assert_allclose(r, expect)
  32. fft.fft(x, workers=2)
  33. def test_invalid_workers(x):
  34. cpus = os.cpu_count()
  35. fft.ifft([1], workers=-cpus)
  36. with pytest.raises(ValueError, match='workers must not be zero'):
  37. fft.fft(x, workers=0)
  38. with pytest.raises(ValueError, match='workers value out of range'):
  39. fft.ifft(x, workers=-cpus-1)
  40. def test_set_get_workers():
  41. cpus = os.cpu_count()
  42. assert fft.get_workers() == 1
  43. with fft.set_workers(4):
  44. assert fft.get_workers() == 4
  45. with fft.set_workers(-1):
  46. assert fft.get_workers() == cpus
  47. assert fft.get_workers() == 4
  48. assert fft.get_workers() == 1
  49. with fft.set_workers(-cpus):
  50. assert fft.get_workers() == 1
  51. def test_set_workers_invalid():
  52. with pytest.raises(ValueError, match='workers must not be zero'):
  53. with fft.set_workers(0):
  54. pass
  55. with pytest.raises(ValueError, match='workers value out of range'):
  56. with fft.set_workers(-os.cpu_count()-1):
  57. pass