test_memory_async.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import asyncio
  2. import gc
  3. import shutil
  4. import pytest
  5. from joblib.memory import (
  6. AsyncMemorizedFunc,
  7. AsyncNotMemorizedFunc,
  8. MemorizedResult,
  9. Memory,
  10. NotMemorizedResult,
  11. )
  12. from joblib.test.common import np, with_numpy
  13. from joblib.testing import raises
  14. from .test_memory import corrupt_single_cache_item, monkeypatch_cached_func_warn
  15. async def check_identity_lazy_async(func, accumulator, location):
  16. """Similar to check_identity_lazy_async for coroutine functions"""
  17. memory = Memory(location=location, verbose=0)
  18. func = memory.cache(func)
  19. for i in range(3):
  20. for _ in range(2):
  21. value = await func(i)
  22. assert value == i
  23. assert len(accumulator) == i + 1
  24. @pytest.mark.asyncio
  25. async def test_memory_integration_async(tmpdir):
  26. accumulator = list()
  27. async def f(n):
  28. await asyncio.sleep(0.1)
  29. accumulator.append(1)
  30. return n
  31. await check_identity_lazy_async(f, accumulator, tmpdir.strpath)
  32. # Now test clearing
  33. for compress in (False, True):
  34. for mmap_mode in ("r", None):
  35. memory = Memory(
  36. location=tmpdir.strpath,
  37. verbose=10,
  38. mmap_mode=mmap_mode,
  39. compress=compress,
  40. )
  41. # First clear the cache directory, to check that our code can
  42. # handle that
  43. # NOTE: this line would raise an exception, as the database
  44. # file is still open; we ignore the error since we want to
  45. # test what happens if the directory disappears
  46. shutil.rmtree(tmpdir.strpath, ignore_errors=True)
  47. g = memory.cache(f)
  48. await g(1)
  49. g.clear(warn=False)
  50. current_accumulator = len(accumulator)
  51. out = await g(1)
  52. assert len(accumulator) == current_accumulator + 1
  53. # Also, check that Memory.eval works similarly
  54. evaled = await memory.eval(f, 1)
  55. assert evaled == out
  56. assert len(accumulator) == current_accumulator + 1
  57. # Now do a smoke test with a function defined in __main__, as the name
  58. # mangling rules are more complex
  59. f.__module__ = "__main__"
  60. memory = Memory(location=tmpdir.strpath, verbose=0)
  61. await memory.cache(f)(1)
  62. @pytest.mark.asyncio
  63. async def test_no_memory_async():
  64. accumulator = list()
  65. async def ff(x):
  66. await asyncio.sleep(0.1)
  67. accumulator.append(1)
  68. return x
  69. memory = Memory(location=None, verbose=0)
  70. gg = memory.cache(ff)
  71. for _ in range(4):
  72. current_accumulator = len(accumulator)
  73. await gg(1)
  74. assert len(accumulator) == current_accumulator + 1
  75. @with_numpy
  76. @pytest.mark.asyncio
  77. async def test_memory_numpy_check_mmap_mode_async(tmpdir, monkeypatch):
  78. """Check that mmap_mode is respected even at the first call"""
  79. memory = Memory(location=tmpdir.strpath, mmap_mode="r", verbose=0)
  80. @memory.cache()
  81. async def twice(a):
  82. return a * 2
  83. a = np.ones(3)
  84. b = await twice(a)
  85. c = await twice(a)
  86. assert isinstance(c, np.memmap)
  87. assert c.mode == "r"
  88. assert isinstance(b, np.memmap)
  89. assert b.mode == "r"
  90. # Corrupts the file, Deleting b and c mmaps
  91. # is necessary to be able edit the file
  92. del b
  93. del c
  94. gc.collect()
  95. corrupt_single_cache_item(memory)
  96. # Make sure that corrupting the file causes recomputation and that
  97. # a warning is issued.
  98. recorded_warnings = monkeypatch_cached_func_warn(twice, monkeypatch)
  99. d = await twice(a)
  100. assert len(recorded_warnings) == 1
  101. exception_msg = "Exception while loading results"
  102. assert exception_msg in recorded_warnings[0]
  103. # Asserts that the recomputation returns a mmap
  104. assert isinstance(d, np.memmap)
  105. assert d.mode == "r"
  106. @pytest.mark.asyncio
  107. async def test_call_and_shelve_async(tmpdir):
  108. async def f(x, y=1):
  109. await asyncio.sleep(0.1)
  110. return x**2 + y
  111. # Test MemorizedFunc outputting a reference to cache.
  112. for func, Result in zip(
  113. (
  114. AsyncMemorizedFunc(f, tmpdir.strpath),
  115. AsyncNotMemorizedFunc(f),
  116. Memory(location=tmpdir.strpath, verbose=0).cache(f),
  117. Memory(location=None).cache(f),
  118. ),
  119. (
  120. MemorizedResult,
  121. NotMemorizedResult,
  122. MemorizedResult,
  123. NotMemorizedResult,
  124. ),
  125. ):
  126. for _ in range(2):
  127. result = await func.call_and_shelve(2)
  128. assert isinstance(result, Result)
  129. assert result.get() == 5
  130. result.clear()
  131. with raises(KeyError):
  132. result.get()
  133. result.clear() # Do nothing if there is no cache.
  134. @pytest.mark.asyncio
  135. async def test_memorized_func_call_async(memory):
  136. async def ff(x, counter):
  137. await asyncio.sleep(0.1)
  138. counter[x] = counter.get(x, 0) + 1
  139. return counter[x]
  140. gg = memory.cache(ff, ignore=["counter"])
  141. counter = {}
  142. assert await gg(2, counter) == 1
  143. assert await gg(2, counter) == 1
  144. x, meta = await gg.call(2, counter)
  145. assert x == 2, "f has not been called properly"
  146. assert isinstance(meta, dict), "Metadata are not returned by MemorizedFunc.call."