test_config.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import os
  2. from joblib._parallel_backends import (
  3. LokyBackend,
  4. MultiprocessingBackend,
  5. ThreadingBackend,
  6. )
  7. from joblib.parallel import (
  8. BACKENDS,
  9. DEFAULT_BACKEND,
  10. EXTERNAL_BACKENDS,
  11. Parallel,
  12. delayed,
  13. parallel_backend,
  14. parallel_config,
  15. )
  16. from joblib.test.common import np, with_multiprocessing, with_numpy
  17. from joblib.test.test_parallel import check_memmap
  18. from joblib.testing import parametrize, raises
  @parametrize("context", [parallel_config, parallel_backend])
  def test_global_parallel_backend(context):
      # Instantiating the context class directly (without a ``with``
      # statement) must change the backend picked by ``Parallel()`` until
      # ``unregister()`` is called.
      baseline_type = type(Parallel()._backend)

      registration = context("threading")
      try:
          active_backend = Parallel()._backend
          assert isinstance(active_backend, ThreadingBackend)
      finally:
          # Undo the registration even if the assertion above fails.
          registration.unregister()

      # After unregistering, ``Parallel()`` must resolve to the same backend
      # type as before the registration.
      restored_type = type(Parallel()._backend)
      assert restored_type is baseline_type
  @parametrize("context", [parallel_config, parallel_backend])
  def test_external_backends(context):
      # Check that a backend known only through EXTERNAL_BACKENDS can be
      # selected by name: its registration callback is expected to add the
      # backend to BACKENDS when the name is requested.
      def register_foo():
          BACKENDS["foo"] = ThreadingBackend

      EXTERNAL_BACKENDS["foo"] = register_foo
      try:
          with context("foo"):
              assert isinstance(Parallel()._backend, ThreadingBackend)
      finally:
          del EXTERNAL_BACKENDS["foo"]
          # register_foo() writes "foo" into the global BACKENDS registry.
          # Remove it too, so that neither the next parametrization of this
          # test nor any other test starts with "foo" already registered.
          BACKENDS.pop("foo", None)
  @with_numpy
  @with_multiprocessing
  def test_parallel_config_no_backend(tmpdir):
      # parallel_config must be able to change Parallel settings (n_jobs,
      # max_nbytes, temp_folder) even when it does not select a backend.
      config = parallel_config(n_jobs=2, max_nbytes=1, temp_folder=tmpdir)
      with config, Parallel(prefer="processes") as workers:
          assert isinstance(workers._backend, LokyBackend)
          assert workers.n_jobs == 2

          # Memmapping must be active: check_memmap is run on each input in
          # the workers, and the temp folder given in the config must have
          # been populated afterwards.
          sample = np.random.random(10)
          workers(delayed(check_memmap)(arr) for arr in (sample, sample))
          assert os.listdir(tmpdir)
  @with_numpy
  @with_multiprocessing
  def test_parallel_config_params_explicit_set(tmpdir):
      # Parameters passed explicitly to Parallel must win over the values
      # supplied by the enclosing parallel_config.
      config = parallel_config(n_jobs=3, max_nbytes=1, temp_folder=tmpdir)
      explicit_params = dict(n_jobs=2, prefer="processes", max_nbytes="1M")
      with config, Parallel(**explicit_params) as workers:
          assert isinstance(workers._backend, LokyBackend)
          assert workers.n_jobs == 2

          # With the explicit max_nbytes="1M", memmapping must be disabled
          # for these small arrays, so check_memmap is expected to fail.
          sample = np.random.random(10)
          with raises(TypeError, match="Expected np.memmap instance"):
              workers(delayed(check_memmap)(arr) for arr in (sample, sample))
  @parametrize("param", ["prefer", "require"])
  def test_parallel_config_bad_params(param):
      # An invalid backend hint or constraint set through parallel_config
      # must be rejected with a ValueError that names the offending value.
      invalid_config = {param: "wrong"}
      expected_message = f"{param}=wrong is not a valid"
      with raises(ValueError, match=expected_message):
          with parallel_config(**invalid_config):
              Parallel()
  def test_parallel_config_constructor_params():
      # Backend constructor parameters are only accepted together with a
      # backend given by name; the other combinations must raise ValueError.
      def assert_rejected(message, **config_kwargs):
          # Entering parallel_config with these arguments must fail with a
          # ValueError whose text matches ``message``.
          with raises(ValueError, match=message):
              with parallel_config(**config_kwargs):
                  pass

      # No backend given at all.
      assert_rejected(
          "only supported when backend is not None",
          inner_max_num_threads=1,
      )
      assert_rejected(
          "only supported when backend is not None",
          backend_param=1,
      )

      # Backend given as an object taken from BACKENDS instead of a name.
      assert_rejected(
          "only supported when backend is a string",
          backend=BACKENDS[DEFAULT_BACKEND],
          backend_param=1,
      )
  def test_parallel_config_nested():
      # An inner parallel_config must inherit the settings of the enclosing
      # configuration instead of resetting them.

      # Only n_jobs is configured: the default backend must still be used.
      with parallel_config(n_jobs=2):
          runner = Parallel()
          assert isinstance(runner._backend, BACKENDS[DEFAULT_BACKEND])
          assert runner.n_jobs == 2

      # The backend chosen by the outer config survives the inner one.
      with parallel_config(backend="threading"), parallel_config(n_jobs=2):
          runner = Parallel()
          assert isinstance(runner._backend, ThreadingBackend)
          assert runner.n_jobs == 2

      # The verbosity chosen by the outer config survives the inner one.
      with parallel_config(verbose=100), parallel_config(n_jobs=2):
          runner = Parallel()
          assert runner.verbose == 100
          assert runner.n_jobs == 2
  @with_numpy
  @with_multiprocessing
  @parametrize(
      "backend",
      [
          "multiprocessing",
          "threading",
          MultiprocessingBackend(),
          ThreadingBackend(),
      ],
  )
  @parametrize("context", [parallel_config, parallel_backend])
  def test_threadpool_limitation_in_child_context_error(context, backend):
      # Requesting inner_max_num_threads for the parametrized backends,
      # given either by name or as an instance, must raise an
      # AssertionError that mentions the parameter.
      expected_error = r"does not acc.*inner_max_num_threads"
      with raises(AssertionError, match=expected_error):
          context(backend, inner_max_num_threads=1)
  @parametrize("context", [parallel_config, parallel_backend])
  def test_parallel_n_jobs_none(context):
      # Non-regression test for #1473: n_jobs=None passed to Parallel means
      # "not set", so the effective value must come from the active context.

      # The context sets n_jobs: Parallel(n_jobs=None) must pick it up.
      with context(backend="threading", n_jobs=2), Parallel(n_jobs=None) as runner:
          assert runner.n_jobs == 2

      # The context leaves n_jobs unset: Parallel(n_jobs=None) must behave
      # like a Parallel() created without any n_jobs argument.
      with context(backend="threading"):
          expected_n_jobs = Parallel().n_jobs
          with Parallel(n_jobs=None) as runner:
              assert runner.n_jobs == expected_n_jobs
  @parametrize("context", [parallel_config, parallel_backend])
  def test_parallel_config_n_jobs_none(context):
      # Non-regression test for #1473: n_jobs=None passed to
      # parallel_config / parallel_backend counts as an explicit setting.
      with context(backend="threading", n_jobs=2):
          inner = context(backend="threading", n_jobs=None)
          with inner, Parallel() as runner:
              # The inner n_jobs=None replaces the outer n_jobs=2, so n_jobs
              # falls back to the backend's default.
              assert runner.n_jobs == 1