| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607 |
- from __future__ import absolute_import, division, print_function
- import os
- import warnings
- from random import random
- from time import sleep
- from uuid import uuid4
- import pytest
- from .. import Parallel, delayed, parallel_backend, parallel_config
- from .._dask import DaskDistributedBackend
- from ..parallel import AutoBatchingMixin, ThreadingBackend
- from .common import np, with_numpy
- from .test_parallel import (
- _recursive_backend_info,
- _test_deadlock_with_generator,
- _test_parallel_unordered_generator_returns_fastest_first, # noqa: E501
- )
- distributed = pytest.importorskip("distributed")
- dask = pytest.importorskip("dask")
- # These imports need to be after the pytest.importorskip hence the noqa: E402
- from distributed import Client, LocalCluster, get_client # noqa: E402
- from distributed.metrics import time # noqa: E402
- # Note: pytest requires to manually import all fixtures used in the test
- # and their dependencies.
- from distributed.utils_test import cleanup, cluster, inc # noqa: E402, F401
@pytest.fixture(scope="function", autouse=True)
def avoid_dask_env_leaks(tmp_path):
    """Put back the tracked environment variables once the test is over.

    Starting a dask nanny might change the environment. The variables named
    in ``ParallelBackendBase.MAX_NUM_THREADS_VARS`` are snapshotted before the
    test runs and restored (or removed again) afterwards.
    """
    from joblib._parallel_backends import ParallelBackendBase

    tracked_names = ParallelBackendBase.MAX_NUM_THREADS_VARS
    snapshot = {name: os.environ.get(name) for name in tracked_names}

    yield

    # Teardown: a variable that was unset before the test must be unset again,
    # every other one gets its previous value back.
    for name, saved in snapshot.items():
        if saved is not None:
            os.environ[name] = saved
        else:
            os.environ.pop(name, None)
def noop(*args, **kwargs):
    """Accept any positional and keyword arguments and do nothing."""
    return None
def slow_raise_value_error(condition, duration=0.05):
    """Sleep for ``duration`` seconds, then fail if ``condition`` is truthy.

    Raises
    ------
    ValueError
        When ``condition`` evaluates to true (after the sleep).
    """
    sleep(duration)
    if not condition:
        return
    raise ValueError("condition evaluated to True")
def count_events(event_name, client):
    """Count, per worker, the log entries whose second item is ``event_name``.

    ``client.run`` is used to fetch the ``log`` attribute of every worker; the
    result is a dict mapping each key returned by ``client.run`` to the number
    of matching entries.
    """
    logs_by_worker = client.run(lambda dask_worker: dask_worker.log)
    return {
        worker: sum(1 for entry in log if entry[1] == event_name)
        for worker, log in logs_by_worker.items()
    }
def test_simple(loop):
    """Basic Parallel calls on the dask backend, before and after a failure."""
    expected = [inc(i) for i in range(10)]

    with cluster() as (s, [a, b]), Client(s["address"], loop=loop):
        with parallel_config(backend="dask"):
            first_run = Parallel()(delayed(inc)(i) for i in range(10))
            assert first_run == expected

            # A task raising ValueError must surface in the caller...
            with pytest.raises(ValueError):
                Parallel()(
                    delayed(slow_raise_value_error)(i == 3) for i in range(10)
                )

            # ... and the backend must still be usable afterwards.
            second_run = Parallel()(delayed(inc)(i) for i in range(10))
            assert second_run == expected
def test_dask_backend_uses_autobatching(loop):
    """The dask backend reuses the auto-batching mixin and actually batches."""
    inherited = DaskDistributedBackend.compute_batch_size
    assert inherited is AutoBatchingMixin.compute_batch_size

    with cluster() as (s, [a, b]), Client(s["address"], loop=loop):
        with parallel_config(backend="dask"), Parallel() as parallel:
            dask_backend = parallel._backend

            # Right after initialization the backend is bound to this
            # Parallel instance and starts with a batch size of 1.
            assert isinstance(dask_backend, DaskDistributedBackend)
            assert dask_backend.parallel is parallel
            assert dask_backend._effective_batch_size == 1

            # A large number of very short tasks should trigger
            # auto-batching and grow the effective batch size.
            n_tasks = int(1e4)
            parallel(delayed(lambda: None)() for _ in range(n_tasks))
            assert dask_backend._effective_batch_size > 10
@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_parallel_unordered_generator_returns_fastest_first_with_dask(n_jobs, context):
    """Run the shared "fastest first" generator check on the dask backend."""
    with distributed.Client(n_workers=2, threads_per_worker=2):
        with context("dask"):
            # The backend argument is None: the backend comes from the
            # active context manager.
            _test_parallel_unordered_generator_returns_fastest_first(None, n_jobs)
@with_numpy
@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("return_as", ["generator", "generator_unordered"])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_deadlock_with_generator_and_dask(context, return_as, n_jobs):
    """Run the shared generator deadlock check on the dask backend."""
    with distributed.Client(n_workers=2, threads_per_worker=2):
        with context("dask"):
            # The backend argument is None: the backend comes from the
            # active context manager.
            _test_deadlock_with_generator(None, return_as, n_jobs)
@with_numpy
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_nested_parallelism_with_dask(context):
    """Every nesting level must report the dask backend, with or without data."""

    def check_all_levels_use_dask(infos):
        assert len(infos) == 4
        assert all(name == "DaskDistributedBackend" for name, _ in infos)

    with distributed.Client(n_workers=2, threads_per_worker=2):
        # 10 MB of data as argument to trigger implicit scattering
        big_array = np.ones(int(1e7), dtype=np.uint8)
        for _ in range(2):
            with context("dask"):
                infos = _recursive_backend_info(data=big_array)
            check_all_levels_use_dask(infos)

        # Same check without passing any argument
        with context("dask"):
            infos = _recursive_backend_info()
        check_all_levels_use_dask(infos)
def random2():
    """Return a fresh value from ``random()`` on every call."""
    value = random()
    return value
def test_dont_assume_function_purity(loop):
    """Two calls to the same argument-less function must both be executed."""
    with cluster() as (s, [a, b]), Client(s["address"], loop=loop):
        with parallel_config(backend="dask"):
            first, second = Parallel()(delayed(random2)() for _ in range(2))
            # Identical results would mean one call was served from a cache.
            assert first != second
@pytest.mark.parametrize("mixed", [True, False])
def test_dask_funcname(loop, mixed):
    """Check the repr of ``Batch`` and the names seen in the scheduler log."""
    from joblib._dask import Batch

    if mixed:
        # Alternate between inc (even i) and abs (odd i).
        tasks = [delayed(abs)(i) if i % 2 else delayed(inc)(i) for i in range(4)]
        batch_repr = "mixed_batch_of_inc_4_calls"
    else:
        tasks = [delayed(inc)(i) for i in range(4)]
        batch_repr = "batch_of_inc_4_calls"

    assert repr(Batch(tasks)) == batch_repr

    def read_transition_log(dask_scheduler):
        return list(dask_scheduler.transition_log)

    with cluster() as (s, [a, b]), Client(s["address"], loop=loop) as client:
        with parallel_config(backend="dask"):
            _ = Parallel(batch_size=2, pre_dispatch="all")(tasks)

        # NOTE(review): this updated name is not read by the assertion below,
        # which only checks for the "batch_of_inc" substring.
        batch_repr = batch_repr.replace("4", "2")
        log = client.run_on_scheduler(read_transition_log)
        assert all("batch_of_inc" in tup[0] for tup in log)
def test_no_undesired_distributed_cache_hit():
    """Non-regression test for https://github.com/joblib/joblib/pull/1055.

    Dask has a pickle cache for callables that are called many times. The
    dask backend used to wrap both the functions and their arguments in
    instances of the Batch callable class, so this cache could lead to bugs.
    The backend no longer bundles the arguments as an attribute of the Batch
    instance; this test guards that behavior.
    """
    # Use many input arguments so that the AutoBatchingMixin has enough tasks
    # to kick in.
    n_lists = 100
    lists = [[] for _ in range(n_lists)]
    np = pytest.importorskip("numpy")
    X = np.arange(int(1e6))
    scatter_event = "receive-from-scatter"

    def isolated_operation(list_, data=None):
        if data is not None:
            np.testing.assert_array_equal(data, X)
        list_.append(uuid4().hex)
        return list_

    local_cluster = LocalCluster(n_workers=1, threads_per_worker=2)
    dask_client = Client(local_cluster)
    try:
        # First pass: dispatches joblib.parallel.BatchedCalls
        with parallel_config(backend="dask"):
            res = Parallel()(delayed(isolated_operation)(list_) for list_ in lists)

        # The mutation happens in the dask worker process, so the original
        # arguments must be left untouched.
        assert lists == [[] for _ in range(n_lists)]

        # No large numpy array was passed to isolated_operation, hence no
        # scattering event is expected under the hood.
        counts = count_events(scatter_event, dask_client)
        assert sum(counts.values()) == 0
        assert all(len(r) == 1 for r in res)

        # Second pass: add a large array, which dask will scatter, and
        # dispatch joblib._dask.Batch
        with parallel_config(backend="dask"):
            res = Parallel()(
                delayed(isolated_operation)(list_, data=X) for list_ in lists
            )

        # This time auto-scattering should have kicked in.
        counts = count_events(scatter_event, dask_client)
        assert sum(counts.values()) > 0
        assert all(len(r) == 1 for r in res)
    finally:
        dask_client.close(timeout=30)
        local_cluster.close(timeout=30)
class CountSerialized(object):
    """Wrap a value and count how many times ``__reduce__`` is called on it."""

    def __init__(self, x):
        self.count = 0
        self.x = x

    def __reduce__(self):
        # Record the serialization request. The rebuilt object is created
        # through the constructor, so it starts again with count == 0.
        self.count += 1
        return (CountSerialized, (self.x,))

    def __add__(self, other):
        # ``other`` may be another wrapper (use its ``x``) or a plain value.
        other_value = getattr(other, "x", other)
        return self.x + other_value

    __radd__ = __add__
def add5(a, b, c, d=0, e=0):
    """Return ``a + b + c + d + e``, adding the terms from left to right."""
    total = a + b
    for term in (c, d, e):
        # Plain ``+`` (not ``+=``) so that no operand is modified in place.
        total = total + term
    return total
def test_manual_scatter(loop):
    """Compare serialization counts of ``Parallel`` and native ``client.submit``.

    The number of times scattered and non-scattered variables are serialized
    must be consistent between ``joblib.Parallel`` calls and the equivalent
    native ``client.submit`` calls.

    The number of serializations can vary from one dask version to another,
    so this test only checks that ``joblib.Parallel`` does not add more
    serialization steps than a native ``client.submit`` call; it does not
    check for an exact number of serialization steps.
    """
    import re

    w, x, y, z = (CountSerialized(i) for i in range(4))

    f = delayed(add5)
    tasks = [f(x, y, z, d=4, e=5) for _ in range(10)]
    tasks += [
        f(x, z, y, d=5, e=4),
        f(y, x, z, d=x, e=5),
        f(z, z, x, d=z, e=y),
    ]
    expected = [func(*args, **kwargs) for func, args, kwargs in tasks]

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", scatter=[w, x, y]):
                results_parallel = Parallel(batch_size=1)(tasks)
                assert results_parallel == expected

            # Check that an error is raised for bad arguments, as scatter must
            # take a list/tuple
            with pytest.raises(TypeError):
                with parallel_config(backend="dask", loop=loop, scatter=1):
                    pass

    # Scattered variables are only serialized during scatter. ``w`` is an
    # extra scattered variable never used in a task: it gives the reference
    # count, as this count can vary from one dask version to another.
    n_serialization_scatter_with_parallel = w.count
    assert x.count == n_serialization_scatter_with_parallel
    assert y.count == n_serialization_scatter_with_parallel
    n_serialization_with_parallel = z.count

    # Reset the serialization counts before running on a fresh cluster
    for var in (w, x, y, z):
        var.count = 0

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            # Scattered futures are looked up by object identity.
            scattered = {
                id(obj): client.scatter(obj, broadcast=True) for obj in (w, x, y)
            }
            results_native = [
                client.submit(
                    func,
                    *(scattered.get(id(arg), arg) for arg in args),
                    **{
                        key: scattered.get(id(value), value)
                        for key, value in kwargs.items()
                    },
                    key=str(uuid4()),
                ).result()
                for (func, args, kwargs) in tasks
            ]
            assert results_native == expected

    # Now check that the number of serialization steps is the same for joblib
    # and native dask calls.
    n_serialization_scatter_native = w.count
    assert x.count == n_serialization_scatter_native
    assert y.count == n_serialization_scatter_native

    assert n_serialization_scatter_with_parallel == n_serialization_scatter_native

    # Only the leading (major, minor) numbers matter for the comparison
    # below. Extracting digit runs, instead of calling int() on every
    # dot-separated piece, keeps this working for version strings carrying a
    # dev/local/rc suffix.
    distributed_version = tuple(
        int(number) for number in re.findall(r"\d+", distributed.__version__)[:2]
    )
    if distributed_version < (2023, 4):
        # Previous to 2023.4, the serialization was adding an extra call to
        # __reduce__ for the last job `f(z, z, x, d=z, e=y)`, because `z`
        # appears both in the args and kwargs, which is not the case when
        # running with joblib. Cope with this discrepancy.
        assert z.count == n_serialization_with_parallel + 1
    else:
        assert z.count == n_serialization_with_parallel
# A single IOLoop serves several clients in a row here, so the
# loop_in_thread fixture is used instead of loop: it prevents the Client
# from closing the loop. See dask/distributed #4112
def test_auto_scatter(loop_in_thread):
    """Large arrays are scattered automatically, small ones are not."""
    np = pytest.importorskip("numpy")
    scatter_event = "receive-from-scatter"
    data1 = np.ones(int(1e4), dtype=np.uint8)
    data2 = np.ones(int(1e4), dtype=np.uint8)
    data_to_process = ([data1] * 3) + ([data2] * 3)

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                # The same data passed as arg and kwarg triggers a single
                # scatter operation whose result is reused.
                Parallel()(
                    delayed(noop)(data, data, i, opt=data)
                    for i, data in enumerate(data_to_process)
                )
            # By default large arrays are automatically scattered with
            # broadcast=1, which means that one worker must directly receive
            # the data from the scatter operation once.
            counts = count_events(scatter_event, client)
            assert counts[a["address"]] + counts[b["address"]] == 2

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                Parallel()(delayed(noop)(data1[:3], i) for i in range(5))
            # Small arrays travel within the task definition and never go
            # through a scatter operation.
            counts = count_events(scatter_event, client)
            assert counts[a["address"]] == 0
            assert counts[b["address"]] == 0
@pytest.mark.parametrize("retry_no", list(range(2)))
def test_nested_scatter(loop, retry_no):
    """Run Parallel calls on array slices from inside dask-executed tasks."""
    np = pytest.importorskip("numpy")

    NUM_INNER_TASKS = 10
    NUM_OUTER_TASKS = 10

    def my_sum(x, i, j):
        # i and j are accepted but only x contributes to the result.
        return np.sum(x)

    def outer_function_joblib(array, i):
        worker_client = get_client()  # noqa
        with parallel_config(backend="dask"):
            partial_sums = Parallel()(
                delayed(my_sum)(array[j:], i, j) for j in range(NUM_INNER_TASKS)
            )
        return sum(partial_sums)

    with cluster() as (s, [a, b]), Client(s["address"], loop=loop):
        with parallel_config(backend="dask"):
            my_array = np.ones(10000)
            # Only successful completion matters; the results are discarded.
            Parallel()(
                delayed(outer_function_joblib)(my_array[i:], i)
                for i in range(NUM_OUTER_TASKS)
            )
def test_nested_backend_context_manager(loop_in_thread):
    """Nested Parallel calls under the dask backend use at most two pids."""

    def get_nested_pids():
        # Two consecutive nested Parallel calls, pids merged into one set.
        pids = set()
        for _ in range(2):
            pids.update(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        return pids

    def run_and_check_pid_groups():
        with parallel_config(backend="dask"):
            pid_groups = Parallel(n_jobs=2)(
                delayed(get_nested_pids)() for _ in range(10)
            )
            for pid_group in pid_groups:
                assert len(set(pid_group)) <= 2

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread):
            run_and_check_pid_groups()

        # No deadlocks when a second client connects to the same scheduler
        with Client(s["address"], loop=loop_in_thread):
            run_and_check_pid_groups()
def test_nested_backend_context_manager_implicit_n_jobs(loop):
    """Parallel without an explicit n_jobs selects all the dask workers.

    This must hold for the outer call as well as for nested calls.
    """
    dask_backend_name = "DaskDistributedBackend"

    def _backend_type(p):
        return p._backend.__class__.__name__

    def get_nested_implicit_n_jobs():
        with Parallel() as nested:
            return _backend_type(nested), nested.n_jobs

    with cluster() as (s, [a, b]), Client(s["address"], loop=loop):
        with parallel_config(backend="dask"):
            with Parallel() as p:
                assert _backend_type(p) == dask_backend_name
                assert p.n_jobs == -1
                all_nested_n_jobs = p(
                    delayed(get_nested_implicit_n_jobs)() for _ in range(2)
                )
            for backend_type, nested_n_jobs in all_nested_n_jobs:
                assert backend_type == dask_backend_name
                assert nested_n_jobs == -1
def test_errors(loop):
    """Selecting the dask backend without a client raises a helpful error."""
    with pytest.raises(ValueError) as excinfo:
        with parallel_config(backend="dask"):
            pass

    message = str(excinfo.value).lower()
    assert "create a dask client" in message
def test_correct_nested_backend(loop):
    """The innermost backend depends on the nested ``require`` constraint."""
    cases = [
        # No requirement, should be us
        (None, DaskDistributedBackend),
        # Require threads, should be threading
        ("sharedmem", ThreadingBackend),
    ]

    with cluster() as (s, [a, b]), Client(s["address"], loop=loop):
        for nested_require, expected_backend in cases:
            with parallel_config(backend="dask"):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require=nested_require) for _ in range(1)
                )
                # One result per nesting level: outer -> middle -> inner.
                assert isinstance(result[0][0][0], expected_backend)
def outer(nested_require):
    """Run ``middle`` once through a thread-preferring Parallel call."""
    runner = Parallel(n_jobs=2, prefer="threads")
    return runner(delayed(middle)(nested_require) for _ in range(1))
def middle(require):
    """Run ``inner`` once through a Parallel call using ``require``."""
    runner = Parallel(n_jobs=2, require=require)
    return runner(delayed(inner)() for _ in range(1))
def inner():
    """Return the backend object of a default-constructed ``Parallel``."""
    default_parallel = Parallel()
    return default_parallel._backend
def test_secede_with_no_processes(loop):
    """Run Parallel against a client started with ``processes=False``.

    See https://github.com/dask/distributed/issues/1775
    """
    with Client(loop=loop, processes=False, set_as_default=True):
        with parallel_config(backend="dask"):
            runner = Parallel(n_jobs=4)
            runner(delayed(id)(i) for i in range(2))
def _worker_address(_):
    """Return the address of the dask worker running this call.

    The single positional argument is ignored.
    """
    from distributed import get_worker

    current_worker = get_worker()
    return current_worker.address
def test_dask_backend_keywords(loop):
    """The ``workers`` keyword pins every task to the requested worker."""
    with cluster() as (s, [a, b]), Client(s["address"], loop=loop):
        # Same check for each worker in turn: first a, then b.
        for worker in (a, b):
            address = worker["address"]
            with parallel_config(backend="dask", workers=address):
                seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
                assert seq == [address] * 10
def test_scheduler_tasks_cleanup(loop):
    """After a Parallel call the scheduler forgets its tasks within 5 seconds."""
    with Client(processes=False, loop=loop) as client:
        with parallel_config(backend="dask"):
            Parallel()(delayed(inc)(i) for i in range(10))

        # Poll until the scheduler holds no task; fail once the deadline
        # has passed.
        deadline = time() + 5
        while client.cluster.scheduler.tasks:
            sleep(0.01)
            assert time() < deadline

        assert not client.futures
def _distributed_version_info():
    """Return up to three leading numbers of ``distributed.__version__``."""
    import re

    # Digit runs are extracted so that suffixes such as "+local" or "rc1"
    # cannot make the conversion to int fail at import time.
    numbers = re.findall(r"\d+", distributed.__version__)
    return tuple(int(number) for number in numbers[:3])


@pytest.mark.parametrize("cluster_strategy", ["adaptive", "late_scaling"])
@pytest.mark.skipif(
    # Numeric comparison: comparing the raw version strings is lexicographic
    # and would e.g. rank "1.9.0" above "1.28.0".
    (1, 28, 0) <= _distributed_version_info() <= (2, 1, 1),
    reason="distributed bug - https://github.com/dask/distributed/pull/2841",
)
def test_wait_for_workers(cluster_strategy):
    """Parallel waits for a worker when the cluster starts without any.

    Parameters
    ----------
    cluster_strategy : str
        "adaptive" enables adaptive scaling between 0 and 2 workers;
        "late_scaling" requests 2 workers after the client is connected.
    """
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    try:
        # Scaling is configured inside the try block so that the client and
        # the cluster are closed even if this step fails.
        if cluster_strategy == "adaptive":
            cluster.adapt(minimum=0, maximum=2)
        elif cluster_strategy == "late_scaling":
            # Tell the cluster to start workers but this is a non-blocking
            # call and new workers might take time to connect. In this case
            # the Parallel call should wait for at least one worker to come
            # up before starting to schedule work.
            cluster.scale(2)

        with parallel_config(backend="dask"):
            # The following should wait a bit for at least one worker to
            # become available.
            Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()
def test_wait_for_workers_timeout():
    """Check the errors raised when the cluster never gets a worker."""
    cases = [
        # Short timeout: the dask backend reports the timeout
        (0.1, TimeoutError, "DaskDistributedBackend has no worker after 0.1 seconds."),
        # No timeout: fallback to the generic joblib failure
        (0, RuntimeError, "DaskDistributedBackend has no active worker"),
    ]

    # Start a cluster with 0 worker
    empty_cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    dask_client = Client(empty_cluster)
    try:
        for timeout, error_type, msg in cases:
            with parallel_config(backend="dask", wait_for_workers_timeout=timeout):
                with pytest.raises(error_type, match=msg):
                    Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        dask_client.close()
        empty_cluster.close()
@pytest.mark.parametrize("backend", ["loky", "multiprocessing"])
def test_joblib_warning_inside_dask_daemonic_worker(backend):
    """A process-based Parallel call inside a dask worker warns exactly once.

    The single recorded warning must be a ``UserWarning`` mentioning
    "distributed.worker.daemon".
    """

    def func_using_joblib_parallel():
        # Somehow trying to check the warning type here (e.g. with
        # pytest.warns(UserWarning)) makes the test hang. Work-around: the
        # warning record is returned to the client and checked client-side.
        with warnings.catch_warnings(record=True) as record:
            Parallel(n_jobs=2, backend=backend)(delayed(inc)(i) for i in range(10))

        return record

    local_cluster = LocalCluster(n_workers=2)
    dask_client = Client(local_cluster)
    try:
        future = dask_client.submit(func_using_joblib_parallel)
        record = future.result()

        assert len(record) == 1
        warning = record[0].message
        assert isinstance(warning, UserWarning)
        assert "distributed.worker.daemon" in str(warning)
    finally:
        dask_client.close(timeout=30)
        local_cluster.close(timeout=30)
|