| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- import asyncio
- import gc
- import shutil
- import pytest
- from joblib.memory import (
- AsyncMemorizedFunc,
- AsyncNotMemorizedFunc,
- MemorizedResult,
- Memory,
- NotMemorizedResult,
- )
- from joblib.test.common import np, with_numpy
- from joblib.testing import raises
- from .test_memory import corrupt_single_cache_item, monkeypatch_cached_func_warn
async def check_identity_lazy_async(func, accumulator, location):
    """Check that a cached coroutine function behaves as a lazy identity.

    Coroutine-function counterpart of the synchronous ``check_identity_lazy``
    test helper.

    Parameters
    ----------
    func : coroutine function
        Must return its single argument and append to ``accumulator`` every
        time its body actually runs.
    accumulator : list
        Must be empty on entry; grows by one element per real (non-cached)
        evaluation of ``func``.
    location : str
        Cache directory handed to ``Memory``.
    """
    memory = Memory(location=location, verbose=0)
    cached_func = memory.cache(func)
    # Call with each argument twice: the first call must evaluate the body,
    # the second must be served from the cache, so the accumulator grows by
    # exactly one per distinct argument.
    for i in range(3):
        for _ in range(2):
            value = await cached_func(i)
            assert value == i
            assert len(accumulator) == i + 1
@pytest.mark.asyncio
async def test_memory_integration_async(tmpdir):
    """Exercise caching, clearing and ``Memory.eval`` on a coroutine function."""
    accumulator = []

    async def f(n):
        await asyncio.sleep(0.1)
        accumulator.append(1)
        return n

    await check_identity_lazy_async(f, accumulator, tmpdir.strpath)

    # Now test clearing, for every compression / memory-mapping combination.
    for compress in (False, True):
        for mmap_mode in ("r", None):
            memory = Memory(
                location=tmpdir.strpath,
                verbose=10,
                mmap_mode=mmap_mode,
                compress=compress,
            )
            # First clear the cache directory, to check that our code can
            # handle a missing cache directory.
            # NOTE: rmtree could raise because the database file is still
            # open; the error is ignored on purpose since we only want to
            # test what happens if the directory disappears.
            shutil.rmtree(tmpdir.strpath, ignore_errors=True)
            g = memory.cache(f)
            await g(1)
            g.clear(warn=False)
            current_accumulator = len(accumulator)
            out = await g(1)
            # The call right after clear() must have re-run f exactly once.
            # These checks live inside the inner loop so every
            # (compress, mmap_mode) pair is verified, not only the last one.
            assert len(accumulator) == current_accumulator + 1
            # Memory.eval with the same function and argument must be served
            # from the cache filled above, without running f again.
            evaled = await memory.eval(f, 1)
            assert evaled == out
            assert len(accumulator) == current_accumulator + 1

    # Now do a smoke test with a function defined in __main__, as the name
    # mangling rules are more complex.
    f.__module__ = "__main__"
    memory = Memory(location=tmpdir.strpath, verbose=0)
    await memory.cache(f)(1)
@pytest.mark.asyncio
async def test_no_memory_async():
    """With ``location=None`` nothing is cached: every call runs the body."""
    calls = []

    async def ff(x):
        await asyncio.sleep(0.1)
        calls.append(1)
        return x

    uncached = Memory(location=None, verbose=0).cache(ff)
    for _ in range(4):
        before = len(calls)
        await uncached(1)
        # Same argument every time, yet the body must run on each call.
        assert len(calls) == before + 1
@with_numpy
@pytest.mark.asyncio
async def test_memory_numpy_check_mmap_mode_async(tmpdir, monkeypatch):
    """Check that ``mmap_mode`` is honoured from the very first call onward."""
    memory = Memory(location=tmpdir.strpath, mmap_mode="r", verbose=0)

    @memory.cache()
    async def twice(a):
        return a * 2

    ones = np.ones(3)
    first = await twice(ones)
    second = await twice(ones)
    # Both the computing call and the cache-hit call must return a
    # read-only memmap.
    assert isinstance(second, np.memmap)
    assert second.mode == "r"
    assert isinstance(first, np.memmap)
    assert first.mode == "r"

    # The memmaps must be released before the cached file can be edited.
    del first, second
    gc.collect()
    corrupt_single_cache_item(memory)

    # A corrupted cache entry must trigger a recomputation together with
    # exactly one warning about the failed load.
    warned = monkeypatch_cached_func_warn(twice, monkeypatch)
    recomputed = await twice(ones)
    assert len(warned) == 1
    assert "Exception while loading results" in warned[0]

    # The recomputed value must come back memory-mapped as well.
    assert isinstance(recomputed, np.memmap)
    assert recomputed.mode == "r"
@pytest.mark.asyncio
async def test_call_and_shelve_async(tmpdir):
    """``call_and_shelve`` returns a result reference of the expected type."""

    async def f(x, y=1):
        await asyncio.sleep(0.1)
        return x**2 + y

    # Each async-decorated variant of f, paired with the result type its
    # call_and_shelve is expected to produce.
    cases = [
        (AsyncMemorizedFunc(f, tmpdir.strpath), MemorizedResult),
        (AsyncNotMemorizedFunc(f), NotMemorizedResult),
        (
            Memory(location=tmpdir.strpath, verbose=0).cache(f),
            MemorizedResult,
        ),
        (Memory(location=None).cache(f), NotMemorizedResult),
    ]

    for func, expected_type in cases:
        # Shelve twice so both a fresh and a repeated call are checked.
        for _ in range(2):
            shelved = await func.call_and_shelve(2)
            assert isinstance(shelved, expected_type)
            assert shelved.get() == 5

        shelved.clear()
        with raises(KeyError):
            shelved.get()
        # A second clear() on an already-emptied result must do nothing.
        shelved.clear()
@pytest.mark.asyncio
async def test_memorized_func_call_async(memory):
    """``.call`` forces a re-run and returns an ``(output, metadata)`` pair."""

    async def ff(x, counter):
        await asyncio.sleep(0.1)
        counter[x] = counter.get(x, 0) + 1
        return counter[x]

    # ``counter`` is excluded from the cache key, so repeated calls with the
    # same ``x`` are answered from the cache after the first evaluation.
    gg = memory.cache(ff, ignore=["counter"])

    calls_by_arg = {}
    for _ in range(2):
        assert await gg(2, calls_by_arg) == 1

    # .call bypasses the cache lookup, so the body runs a second time.
    output, metadata = await gg.call(2, calls_by_arg)
    assert output == 2, "f has not been called properly"
    assert isinstance(metadata, dict), "Metadata are not returned by MemorizedFunc.call."
|