# test_memory.py
"""
Test the memory module.
"""

# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# Copyright (c) 2009 Gael Varoquaux
# License: BSD Style, 3 clauses.
import datetime
import functools
import gc
import logging
import os
import os.path
import pickle
import shutil
import sys
import textwrap
import time
from pathlib import Path

import pytest

from joblib._store_backends import FileSystemStoreBackend, StoreBackendBase
from joblib.hashing import hash
from joblib.memory import (
    _FUNCTION_HASHES,
    _STORE_BACKENDS,
    JobLibCollisionWarning,
    MemorizedFunc,
    MemorizedResult,
    Memory,
    NotMemorizedFunc,
    NotMemorizedResult,
    _build_func_identifier,
    _store_backend_factory,
    expires_after,
    register_store_backend,
)
from joblib.parallel import Parallel, delayed
from joblib.test.common import np, with_multiprocessing, with_numpy
from joblib.testing import parametrize, raises, warns


###############################################################################
# Module-level function for the tests
  41. def f(x, y=1):
  42. """A module-level function for testing purposes."""
  43. return x**2 + y


###############################################################################
# Helper functions for the tests
  46. def check_identity_lazy(func, accumulator, location):
  47. """Given a function and an accumulator (a list that grows every
  48. time the function is called), check that the function can be
  49. decorated by memory to be a lazy identity.
  50. """
  51. # Call each function with several arguments, and check that it is
  52. # evaluated only once per argument.
  53. memory = Memory(location=location, verbose=0)
  54. func = memory.cache(func)
  55. for i in range(3):
  56. for _ in range(2):
  57. assert func(i) == i
  58. assert len(accumulator) == i + 1
  59. def corrupt_single_cache_item(memory):
  60. (single_cache_item,) = memory.store_backend.get_items()
  61. output_filename = os.path.join(single_cache_item.path, "output.pkl")
  62. with open(output_filename, "w") as f:
  63. f.write("garbage")
  64. def monkeypatch_cached_func_warn(func, monkeypatch_fixture):
  65. # Need monkeypatch because pytest does not
  66. # capture stdlib logging output (see
  67. # https://github.com/pytest-dev/pytest/issues/2079)
  68. recorded = []
  69. def append_to_record(item):
  70. recorded.append(item)
  71. monkeypatch_fixture.setattr(func, "warn", append_to_record)
  72. return recorded


###############################################################################
# Tests
  75. def test_memory_integration(tmpdir):
  76. """Simple test of memory lazy evaluation."""
  77. accumulator = list()
  78. # Rmk: this function has the same name than a module-level function,
  79. # thus it serves as a test to see that both are identified
  80. # as different.
  81. def f(arg):
  82. accumulator.append(1)
  83. return arg
  84. check_identity_lazy(f, accumulator, tmpdir.strpath)
  85. # Now test clearing
  86. for compress in (False, True):
  87. for mmap_mode in ("r", None):
  88. memory = Memory(
  89. location=tmpdir.strpath,
  90. verbose=10,
  91. mmap_mode=mmap_mode,
  92. compress=compress,
  93. )
  94. # First clear the cache directory, to check that our code can
  95. # handle that
  96. # NOTE: this line would raise an exception, as the database file is
  97. # still open; we ignore the error since we want to test what
  98. # happens if the directory disappears
  99. shutil.rmtree(tmpdir.strpath, ignore_errors=True)
  100. g = memory.cache(f)
  101. g(1)
  102. g.clear(warn=False)
  103. current_accumulator = len(accumulator)
  104. out = g(1)
  105. assert len(accumulator) == current_accumulator + 1
  106. # Also, check that Memory.eval works similarly
  107. assert memory.eval(f, 1) == out
  108. assert len(accumulator) == current_accumulator + 1
  109. # Now do a smoke test with a function defined in __main__, as the name
  110. # mangling rules are more complex
  111. f.__module__ = "__main__"
  112. memory = Memory(location=tmpdir.strpath, verbose=0)
  113. memory.cache(f)(1)
  114. @parametrize("call_before_reducing", [True, False])
  115. def test_parallel_call_cached_function_defined_in_jupyter(tmpdir, call_before_reducing):
  116. # Calling an interactively defined memory.cache()'d function inside a
  117. # Parallel call used to clear the existing cache related to the said
  118. # function (https://github.com/joblib/joblib/issues/1035)
  119. # This tests checks that this is no longer the case.
  120. # TODO: test that the cache related to the function cache persists across
  121. # ipython sessions (provided that no code change were made to the
  122. # function's source)?
  123. # The first part of the test makes the necessary low-level calls to emulate
  124. # the definition of a function in an jupyter notebook cell. Joblib has
  125. # some custom code to treat functions defined specifically in jupyter
  126. # notebooks/ipython session -- we want to test this code, which requires
  127. # the emulation to be rigorous.
  128. for session_no in [0, 1]:
  129. ipython_cell_source = """
  130. def f(x):
  131. return x
  132. """
  133. ipython_cell_id = "<ipython-input-{}-000000000000>".format(session_no)
  134. my_locals = {}
  135. exec(
  136. compile(
  137. textwrap.dedent(ipython_cell_source),
  138. filename=ipython_cell_id,
  139. mode="exec",
  140. ),
  141. # TODO when Python 3.11 is the minimum supported version, use
  142. # locals=my_locals instead of passing globals and locals in the
  143. # next two lines as positional arguments
  144. None,
  145. my_locals,
  146. )
  147. f = my_locals["f"]
  148. f.__module__ = "__main__"
  149. # Preliminary sanity checks, and tests checking that joblib properly
  150. # identified f as an interactive function defined in a jupyter notebook
  151. assert f(1) == 1
  152. assert f.__code__.co_filename == ipython_cell_id
  153. memory = Memory(location=tmpdir.strpath, verbose=0)
  154. cached_f = memory.cache(f)
  155. assert len(os.listdir(tmpdir / "joblib")) == 1
  156. f_cache_relative_directory = os.listdir(tmpdir / "joblib")[0]
  157. assert "ipython-input" in f_cache_relative_directory
  158. f_cache_directory = tmpdir / "joblib" / f_cache_relative_directory
  159. if session_no == 0:
  160. # The cache should be empty as cached_f has not been called yet.
  161. assert os.listdir(f_cache_directory) == ["f"]
  162. assert os.listdir(f_cache_directory / "f") == []
  163. if call_before_reducing:
  164. cached_f(3)
  165. # Two files were just created, func_code.py, and a folder
  166. # containing the information (inputs hash/ouptput) of
  167. # cached_f(3)
  168. assert len(os.listdir(f_cache_directory / "f")) == 2
  169. # Now, testing #1035: when calling a cached function, joblib
  170. # used to dynamically inspect the underlying function to
  171. # extract its source code (to verify it matches the source code
  172. # of the function as last inspected by joblib) -- however,
  173. # source code introspection fails for dynamic functions sent to
  174. # child processes - which would eventually make joblib clear
  175. # the cache associated to f
  176. Parallel(n_jobs=2)(delayed(cached_f)(i) for i in [1, 2])
  177. else:
  178. # Submit the function to the joblib child processes, although
  179. # the function has never been called in the parent yet. This
  180. # triggers a specific code branch inside
  181. # MemorizedFunc.__reduce__.
  182. Parallel(n_jobs=2)(delayed(cached_f)(i) for i in [1, 2])
  183. # Ensure the child process has time to close the file.
  184. # Wait up to 5 seconds for slow CI runs
  185. for _ in range(25):
  186. if len(os.listdir(f_cache_directory / "f")) == 3:
  187. break
  188. time.sleep(0.2) # pragma: no cover
  189. assert len(os.listdir(f_cache_directory / "f")) == 3
  190. cached_f(3)
  191. # Making sure f's cache does not get cleared after the parallel
  192. # calls, and contains ALL cached functions calls (f(1), f(2), f(3))
  193. # and 'func_code.py'
  194. assert len(os.listdir(f_cache_directory / "f")) == 4
  195. else:
  196. # For the second session, there should be an already existing cache
  197. assert len(os.listdir(f_cache_directory / "f")) == 4
  198. cached_f(3)
  199. # The previous cache should not be invalidated after calling the
  200. # function in a new session
  201. assert len(os.listdir(f_cache_directory / "f")) == 4
  202. def test_no_memory():
  203. """Test memory with location=None: no memoize"""
  204. accumulator = list()
  205. def ff(arg):
  206. accumulator.append(1)
  207. return arg
  208. memory = Memory(location=None, verbose=0)
  209. gg = memory.cache(ff)
  210. for _ in range(4):
  211. current_accumulator = len(accumulator)
  212. gg(1)
  213. assert len(accumulator) == current_accumulator + 1
  214. def test_memory_kwarg(tmpdir):
  215. "Test memory with a function with keyword arguments."
  216. accumulator = list()
  217. def g(arg1=None, arg2=1):
  218. accumulator.append(1)
  219. return arg1
  220. check_identity_lazy(g, accumulator, tmpdir.strpath)
  221. memory = Memory(location=tmpdir.strpath, verbose=0)
  222. g = memory.cache(g)
  223. # Smoke test with an explicit keyword argument:
  224. assert g(arg1=30, arg2=2) == 30
  225. def test_memory_lambda(tmpdir):
  226. "Test memory with a function with a lambda."
  227. accumulator = list()
  228. def helper(x):
  229. """A helper function to define l as a lambda."""
  230. accumulator.append(1)
  231. return x
  232. check_identity_lazy(lambda x: helper(x), accumulator, tmpdir.strpath)
  233. def test_memory_name_collision(tmpdir):
  234. "Check that name collisions with functions will raise warnings"
  235. memory = Memory(location=tmpdir.strpath, verbose=0)
  236. @memory.cache
  237. def name_collision(x):
  238. """A first function called name_collision"""
  239. return x
  240. a = name_collision
  241. @memory.cache
  242. def name_collision(x):
  243. """A second function called name_collision"""
  244. return x
  245. b = name_collision
  246. with warns(JobLibCollisionWarning) as warninfo:
  247. a(1)
  248. b(1)
  249. assert len(warninfo) == 1
  250. assert "collision" in str(warninfo[0].message)
  251. def test_memory_warning_lambda_collisions(tmpdir):
  252. # Check that multiple use of lambda will raise collisions
  253. memory = Memory(location=tmpdir.strpath, verbose=0)
  254. a = memory.cache(lambda x: x)
  255. b = memory.cache(lambda x: x + 1)
  256. with warns(JobLibCollisionWarning) as warninfo:
  257. assert a(0) == 0
  258. assert b(1) == 2
  259. assert a(1) == 1
  260. # In recent Python versions, we can retrieve the code of lambdas,
  261. # thus nothing is raised
  262. assert len(warninfo) == 4
  263. def test_memory_warning_collision_detection(tmpdir):
  264. # Check that collisions impossible to detect will raise appropriate
  265. # warnings.
  266. memory = Memory(location=tmpdir.strpath, verbose=0)
  267. a1 = eval("lambda x: x")
  268. a1 = memory.cache(a1)
  269. b1 = eval("lambda x: x+1")
  270. b1 = memory.cache(b1)
  271. with warns(JobLibCollisionWarning) as warninfo:
  272. a1(1)
  273. b1(1)
  274. a1(0)
  275. assert len(warninfo) == 2
  276. assert "cannot detect" in str(warninfo[0].message).lower()
  277. def test_memory_partial(tmpdir):
  278. "Test memory with functools.partial."
  279. accumulator = list()
  280. def func(x, y):
  281. """A helper function to define l as a lambda."""
  282. accumulator.append(1)
  283. return y
  284. import functools
  285. function = functools.partial(func, 1)
  286. check_identity_lazy(function, accumulator, tmpdir.strpath)
  287. def test_memory_eval(tmpdir):
  288. "Smoke test memory with a function with a function defined in an eval."
  289. memory = Memory(location=tmpdir.strpath, verbose=0)
  290. m = eval("lambda x: x")
  291. mm = memory.cache(m)
  292. assert mm(1) == 1
  293. def count_and_append(x=[]):
  294. """A function with a side effect in its arguments.
  295. Return the length of its argument and append one element.
  296. """
  297. len_x = len(x)
  298. x.append(None)
  299. return len_x
  300. def test_argument_change(tmpdir):
  301. """Check that if a function has a side effect in its arguments, it
  302. should use the hash of changing arguments.
  303. """
  304. memory = Memory(location=tmpdir.strpath, verbose=0)
  305. func = memory.cache(count_and_append)
  306. # call the function for the first time, is should cache it with
  307. # argument x=[]
  308. assert func() == 0
  309. # the second time the argument is x=[None], which is not cached
  310. # yet, so the functions should be called a second time
  311. assert func() == 1
  312. @with_numpy
  313. @parametrize("mmap_mode", [None, "r"])
  314. def test_memory_numpy(tmpdir, mmap_mode):
  315. "Test memory with a function with numpy arrays."
  316. accumulator = list()
  317. def n(arg=None):
  318. accumulator.append(1)
  319. return arg
  320. memory = Memory(location=tmpdir.strpath, mmap_mode=mmap_mode, verbose=0)
  321. cached_n = memory.cache(n)
  322. rnd = np.random.RandomState(0)
  323. for i in range(3):
  324. a = rnd.random_sample((10, 10))
  325. for _ in range(3):
  326. assert np.all(cached_n(a) == a)
  327. assert len(accumulator) == i + 1
  328. @with_numpy
  329. def test_memory_numpy_check_mmap_mode(tmpdir, monkeypatch):
  330. """Check that mmap_mode is respected even at the first call"""
  331. memory = Memory(location=tmpdir.strpath, mmap_mode="r", verbose=0)
  332. @memory.cache()
  333. def twice(a):
  334. return a * 2
  335. a = np.ones(3)
  336. b = twice(a)
  337. c = twice(a)
  338. assert isinstance(c, np.memmap)
  339. assert c.mode == "r"
  340. assert isinstance(b, np.memmap)
  341. assert b.mode == "r"
  342. # Corrupts the file, Deleting b and c mmaps
  343. # is necessary to be able edit the file
  344. del b
  345. del c
  346. gc.collect()
  347. corrupt_single_cache_item(memory)
  348. # Make sure that corrupting the file causes recomputation and that
  349. # a warning is issued.
  350. recorded_warnings = monkeypatch_cached_func_warn(twice, monkeypatch)
  351. d = twice(a)
  352. assert len(recorded_warnings) == 1
  353. exception_msg = "Exception while loading results"
  354. assert exception_msg in recorded_warnings[0]
  355. # Asserts that the recomputation returns a mmap
  356. assert isinstance(d, np.memmap)
  357. assert d.mode == "r"
  358. def test_memory_exception(tmpdir):
  359. """Smoketest the exception handling of Memory."""
  360. memory = Memory(location=tmpdir.strpath, verbose=0)
  361. class MyException(Exception):
  362. pass
  363. @memory.cache
  364. def h(exc=0):
  365. if exc:
  366. raise MyException
  367. # Call once, to initialise the cache
  368. h()
  369. for _ in range(3):
  370. # Call 3 times, to be sure that the Exception is always raised
  371. with raises(MyException):
  372. h(1)
  373. def test_memory_ignore(tmpdir):
  374. "Test the ignore feature of memory"
  375. memory = Memory(location=tmpdir.strpath, verbose=0)
  376. accumulator = list()
  377. @memory.cache(ignore=["y"])
  378. def z(x, y=1):
  379. accumulator.append(1)
  380. assert z.ignore == ["y"]
  381. z(0, y=1)
  382. assert len(accumulator) == 1
  383. z(0, y=1)
  384. assert len(accumulator) == 1
  385. z(0, y=2)
  386. assert len(accumulator) == 1
  387. def test_memory_ignore_decorated(tmpdir):
  388. "Test the ignore feature of memory on a decorated function"
  389. memory = Memory(location=tmpdir.strpath, verbose=0)
  390. accumulator = list()
  391. def decorate(f):
  392. @functools.wraps(f)
  393. def wrapped(*args, **kwargs):
  394. return f(*args, **kwargs)
  395. return wrapped
  396. @memory.cache(ignore=["y"])
  397. @decorate
  398. def z(x, y=1):
  399. accumulator.append(1)
  400. assert z.ignore == ["y"]
  401. z(0, y=1)
  402. assert len(accumulator) == 1
  403. z(0, y=1)
  404. assert len(accumulator) == 1
  405. z(0, y=2)
  406. assert len(accumulator) == 1
  407. def test_memory_args_as_kwargs(tmpdir):
  408. """Non-regression test against 0.12.0 changes.
  409. https://github.com/joblib/joblib/pull/751
  410. """
  411. memory = Memory(location=tmpdir.strpath, verbose=0)
  412. @memory.cache
  413. def plus_one(a):
  414. return a + 1
  415. # It's possible to call a positional arg as a kwarg.
  416. assert plus_one(1) == 2
  417. assert plus_one(a=1) == 2
  418. # However, a positional argument that joblib hadn't seen
  419. # before would cause a failure if it was passed as a kwarg.
  420. assert plus_one(a=2) == 3
  421. @parametrize("ignore, verbose, mmap_mode", [(["x"], 100, "r"), ([], 10, None)])
  422. def test_partial_decoration(tmpdir, ignore, verbose, mmap_mode):
  423. "Check cache may be called with kwargs before decorating"
  424. memory = Memory(location=tmpdir.strpath, verbose=0)
  425. @memory.cache(ignore=ignore, verbose=verbose, mmap_mode=mmap_mode)
  426. def z(x):
  427. pass
  428. assert z.ignore == ignore
  429. assert z._verbose == verbose
  430. assert z.mmap_mode == mmap_mode
  431. def test_func_dir(tmpdir):
  432. # Test the creation of the memory cache directory for the function.
  433. memory = Memory(location=tmpdir.strpath, verbose=0)
  434. path = __name__.split(".")
  435. path.append("f")
  436. path = tmpdir.join("joblib", *path).strpath
  437. g = memory.cache(f)
  438. # Test that the function directory is created on demand
  439. func_id = _build_func_identifier(f)
  440. location = os.path.join(g.store_backend.location, func_id)
  441. assert location == path
  442. assert os.path.exists(path)
  443. assert memory.location == os.path.dirname(g.store_backend.location)
  444. # Test that the code is stored.
  445. # For the following test to be robust to previous execution, we clear
  446. # the in-memory store
  447. _FUNCTION_HASHES.clear()
  448. assert not g._check_previous_func_code()
  449. assert os.path.exists(os.path.join(path, "func_code.py"))
  450. assert g._check_previous_func_code()
  451. # Test the robustness to failure of loading previous results.
  452. args_id = g._get_args_id(1)
  453. output_dir = os.path.join(g.store_backend.location, g.func_id, args_id)
  454. a = g(1)
  455. assert os.path.exists(output_dir)
  456. os.remove(os.path.join(output_dir, "output.pkl"))
  457. assert a == g(1)
  458. def test_persistence(tmpdir):
  459. # Test the memorized functions can be pickled and restored.
  460. memory = Memory(location=tmpdir.strpath, verbose=0)
  461. g = memory.cache(f)
  462. output = g(1)
  463. h = pickle.loads(pickle.dumps(g))
  464. args_id = h._get_args_id(1)
  465. output_dir = os.path.join(h.store_backend.location, h.func_id, args_id)
  466. assert os.path.exists(output_dir)
  467. assert output == h.store_backend.load_item([h.func_id, args_id])
  468. memory2 = pickle.loads(pickle.dumps(memory))
  469. assert memory.store_backend.location == memory2.store_backend.location
  470. # Smoke test that pickling a memory with location=None works
  471. memory = Memory(location=None, verbose=0)
  472. pickle.loads(pickle.dumps(memory))
  473. g = memory.cache(f)
  474. gp = pickle.loads(pickle.dumps(g))
  475. gp(1)
  476. @pytest.mark.parametrize("consider_cache_valid", [True, False])
  477. def test_check_call_in_cache(tmpdir, consider_cache_valid):
  478. for func in (
  479. MemorizedFunc(
  480. f, tmpdir.strpath, cache_validation_callback=lambda _: consider_cache_valid
  481. ),
  482. Memory(location=tmpdir.strpath, verbose=0).cache(
  483. f, cache_validation_callback=lambda _: consider_cache_valid
  484. ),
  485. ):
  486. result = func.check_call_in_cache(2)
  487. assert isinstance(result, bool)
  488. assert not result
  489. assert func(2) == 5
  490. result = func.check_call_in_cache(2)
  491. assert isinstance(result, bool)
  492. assert result == consider_cache_valid
  493. func.clear()
  494. func = NotMemorizedFunc(f)
  495. assert not func.check_call_in_cache(2)
  496. def test_call_and_shelve(tmpdir):
  497. # Test MemorizedFunc outputting a reference to cache.
  498. for func, Result in zip(
  499. (
  500. MemorizedFunc(f, tmpdir.strpath),
  501. NotMemorizedFunc(f),
  502. Memory(location=tmpdir.strpath, verbose=0).cache(f),
  503. Memory(location=None).cache(f),
  504. ),
  505. (MemorizedResult, NotMemorizedResult, MemorizedResult, NotMemorizedResult),
  506. ):
  507. assert func(2) == 5
  508. result = func.call_and_shelve(2)
  509. assert isinstance(result, Result)
  510. assert result.get() == 5
  511. result.clear()
  512. with raises(KeyError):
  513. result.get()
  514. result.clear() # Do nothing if there is no cache.
  515. def test_call_and_shelve_lazily_load_stored_result(tmpdir):
  516. """Check call_and_shelve only load stored data if needed."""
  517. test_access_time_file = tmpdir.join("test_access")
  518. test_access_time_file.write("test_access")
  519. test_access_time = os.stat(test_access_time_file.strpath).st_atime
  520. # check file system access time stats resolution is lower than test wait
  521. # timings.
  522. time.sleep(0.5)
  523. assert test_access_time_file.read() == "test_access"
  524. if test_access_time == os.stat(test_access_time_file.strpath).st_atime:
  525. # Skip this test when access time cannot be retrieved with enough
  526. # precision from the file system (e.g. NTFS on windows).
  527. pytest.skip("filesystem does not support fine-grained access time attribute")
  528. memory = Memory(location=tmpdir.strpath, verbose=0)
  529. func = memory.cache(f)
  530. args_id = func._get_args_id(2)
  531. result_path = os.path.join(
  532. memory.store_backend.location, func.func_id, args_id, "output.pkl"
  533. )
  534. assert func(2) == 5
  535. first_access_time = os.stat(result_path).st_atime
  536. time.sleep(1)
  537. # Should not access the stored data
  538. result = func.call_and_shelve(2)
  539. assert isinstance(result, MemorizedResult)
  540. assert os.stat(result_path).st_atime == first_access_time
  541. time.sleep(1)
  542. # Read the stored data => last access time is greater than first_access
  543. assert result.get() == 5
  544. assert os.stat(result_path).st_atime > first_access_time
  545. def test_memorized_pickling(tmpdir):
  546. for func in (MemorizedFunc(f, tmpdir.strpath), NotMemorizedFunc(f)):
  547. filename = tmpdir.join("pickling_test.dat").strpath
  548. result = func.call_and_shelve(2)
  549. with open(filename, "wb") as fp:
  550. pickle.dump(result, fp)
  551. with open(filename, "rb") as fp:
  552. result2 = pickle.load(fp)
  553. assert result2.get() == result.get()
  554. os.remove(filename)
  555. def test_memorized_repr(tmpdir):
  556. func = MemorizedFunc(f, tmpdir.strpath)
  557. result = func.call_and_shelve(2)
  558. func2 = MemorizedFunc(f, tmpdir.strpath)
  559. result2 = func2.call_and_shelve(2)
  560. assert result.get() == result2.get()
  561. assert repr(func) == repr(func2)
  562. # Smoke test with NotMemorizedFunc
  563. func = NotMemorizedFunc(f)
  564. repr(func)
  565. repr(func.call_and_shelve(2))
  566. # Smoke test for message output (increase code coverage)
  567. func = MemorizedFunc(f, tmpdir.strpath, verbose=11, timestamp=time.time())
  568. result = func.call_and_shelve(11)
  569. result.get()
  570. func = MemorizedFunc(f, tmpdir.strpath, verbose=11)
  571. result = func.call_and_shelve(11)
  572. result.get()
  573. func = MemorizedFunc(f, tmpdir.strpath, verbose=5, timestamp=time.time())
  574. result = func.call_and_shelve(11)
  575. result.get()
  576. func = MemorizedFunc(f, tmpdir.strpath, verbose=5)
  577. result = func.call_and_shelve(11)
  578. result.get()
  579. def test_memory_file_modification(capsys, tmpdir, monkeypatch):
  580. # Test that modifying a Python file after loading it does not lead to
  581. # Recomputation
  582. dir_name = tmpdir.mkdir("tmp_import").strpath
  583. filename = os.path.join(dir_name, "tmp_joblib_.py")
  584. content = "def f(x):\n print(x)\n return x\n"
  585. with open(filename, "w") as module_file:
  586. module_file.write(content)
  587. # Load the module:
  588. monkeypatch.syspath_prepend(dir_name)
  589. import tmp_joblib_ as tmp
  590. memory = Memory(location=tmpdir.strpath, verbose=0)
  591. f = memory.cache(tmp.f)
  592. # First call f a few times
  593. f(1)
  594. f(2)
  595. f(1)
  596. # Now modify the module where f is stored without modifying f
  597. with open(filename, "w") as module_file:
  598. module_file.write("\n\n" + content)
  599. # And call f a couple more times
  600. f(1)
  601. f(1)
  602. # Flush the .pyc files
  603. shutil.rmtree(dir_name)
  604. os.mkdir(dir_name)
  605. # Now modify the module where f is stored, modifying f
  606. content = 'def f(x):\n print("x=%s" % x)\n return x\n'
  607. with open(filename, "w") as module_file:
  608. module_file.write(content)
  609. # And call f more times prior to reloading: the cache should not be
  610. # invalidated at this point as the active function definition has not
  611. # changed in memory yet.
  612. f(1)
  613. f(1)
  614. # Now reload
  615. sys.stdout.write("Reloading\n")
  616. sys.modules.pop("tmp_joblib_")
  617. import tmp_joblib_ as tmp
  618. f = memory.cache(tmp.f)
  619. # And call f more times
  620. f(1)
  621. f(1)
  622. out, err = capsys.readouterr()
  623. assert out == "1\n2\nReloading\nx=1\n"
  624. def _function_to_cache(a, b):
  625. # Just a place holder function to be mutated by tests
  626. pass
  627. def _sum(a, b):
  628. return a + b
  629. def _product(a, b):
  630. return a * b
  631. def test_memory_in_memory_function_code_change(tmpdir):
  632. _function_to_cache.__code__ = _sum.__code__
  633. memory = Memory(location=tmpdir.strpath, verbose=0)
  634. f = memory.cache(_function_to_cache)
  635. assert f(1, 2) == 3
  636. assert f(1, 2) == 3
  637. with warns(JobLibCollisionWarning):
  638. # Check that inline function modification triggers a cache invalidation
  639. _function_to_cache.__code__ = _product.__code__
  640. assert f(1, 2) == 2
  641. assert f(1, 2) == 2
  642. def test_clear_memory_with_none_location():
  643. memory = Memory(location=None)
  644. memory.clear()
  645. def func_with_kwonly_args(a, b, *, kw1="kw1", kw2="kw2"):
  646. return a, b, kw1, kw2
  647. def func_with_signature(a: int, b: float) -> float:
  648. return a + b
  649. def test_memory_func_with_kwonly_args(tmpdir):
  650. memory = Memory(location=tmpdir.strpath, verbose=0)
  651. func_cached = memory.cache(func_with_kwonly_args)
  652. assert func_cached(1, 2, kw1=3) == (1, 2, 3, "kw2")
  653. # Making sure that providing a keyword-only argument by
  654. # position raises an exception
  655. with raises(ValueError) as excinfo:
  656. func_cached(1, 2, 3, kw2=4)
  657. excinfo.match("Keyword-only parameter 'kw1' was passed as positional parameter")
  658. # Keyword-only parameter passed by position with cached call
  659. # should still raise ValueError
  660. func_cached(1, 2, kw1=3, kw2=4)
  661. with raises(ValueError) as excinfo:
  662. func_cached(1, 2, 3, kw2=4)
  663. excinfo.match("Keyword-only parameter 'kw1' was passed as positional parameter")
  664. # Test 'ignore' parameter
  665. func_cached = memory.cache(func_with_kwonly_args, ignore=["kw2"])
  666. assert func_cached(1, 2, kw1=3, kw2=4) == (1, 2, 3, 4)
  667. assert func_cached(1, 2, kw1=3, kw2="ignored") == (1, 2, 3, 4)
  668. def test_memory_func_with_signature(tmpdir):
  669. memory = Memory(location=tmpdir.strpath, verbose=0)
  670. func_cached = memory.cache(func_with_signature)
  671. assert func_cached(1, 2.0) == 3.0
  672. def _setup_toy_cache(tmpdir, num_inputs=10):
  673. memory = Memory(location=tmpdir.strpath, verbose=0)
  674. @memory.cache()
  675. def get_1000_bytes(arg):
  676. return "a" * 1000
  677. inputs = list(range(num_inputs))
  678. for arg in inputs:
  679. get_1000_bytes(arg)
  680. func_id = _build_func_identifier(get_1000_bytes)
  681. hash_dirnames = [get_1000_bytes._get_args_id(arg) for arg in inputs]
  682. full_hashdirs = [
  683. os.path.join(get_1000_bytes.store_backend.location, func_id, dirname)
  684. for dirname in hash_dirnames
  685. ]
  686. return memory, full_hashdirs, get_1000_bytes
  687. def test__get_items(tmpdir):
  688. memory, expected_hash_dirs, _ = _setup_toy_cache(tmpdir)
  689. items = memory.store_backend.get_items()
  690. hash_dirs = [ci.path for ci in items]
  691. assert set(hash_dirs) == set(expected_hash_dirs)
  692. def get_files_size(directory):
  693. full_paths = [os.path.join(directory, fn) for fn in os.listdir(directory)]
  694. return sum(os.path.getsize(fp) for fp in full_paths)
  695. expected_hash_cache_sizes = [get_files_size(hash_dir) for hash_dir in hash_dirs]
  696. hash_cache_sizes = [ci.size for ci in items]
  697. assert hash_cache_sizes == expected_hash_cache_sizes
  698. output_filenames = [os.path.join(hash_dir, "output.pkl") for hash_dir in hash_dirs]
  699. expected_last_accesses = [
  700. datetime.datetime.fromtimestamp(os.path.getatime(fn)) for fn in output_filenames
  701. ]
  702. last_accesses = [ci.last_access for ci in items]
  703. assert last_accesses == expected_last_accesses
  704. def test__get_items_to_delete(tmpdir):
  705. # test empty cache
  706. memory, _, _ = _setup_toy_cache(tmpdir, num_inputs=0)
  707. items_to_delete = memory.store_backend._get_items_to_delete("1K")
  708. assert items_to_delete == []
  709. memory, expected_hash_cachedirs, _ = _setup_toy_cache(tmpdir)
  710. items = memory.store_backend.get_items()
  711. # bytes_limit set to keep only one cache item (each hash cache
  712. # folder is about 1000 bytes + metadata)
  713. items_to_delete = memory.store_backend._get_items_to_delete("2K")
  714. nb_hashes = len(expected_hash_cachedirs)
  715. assert set.issubset(set(items_to_delete), set(items))
  716. assert len(items_to_delete) == nb_hashes - 1
  717. # Sanity check bytes_limit=2048 is the same as bytes_limit='2K'
  718. items_to_delete_2048b = memory.store_backend._get_items_to_delete(2048)
  719. assert sorted(items_to_delete) == sorted(items_to_delete_2048b)
  720. # bytes_limit greater than the size of the cache
  721. items_to_delete_empty = memory.store_backend._get_items_to_delete("1M")
  722. assert items_to_delete_empty == []
  723. # All the cache items need to be deleted
  724. bytes_limit_too_small = 500
  725. items_to_delete_500b = memory.store_backend._get_items_to_delete(
  726. bytes_limit_too_small
  727. )
  728. assert set(items_to_delete_500b), set(items)
  729. # Test LRU property: surviving cache items should all have a more
  730. # recent last_access that the ones that have been deleted
  731. items_to_delete_6000b = memory.store_backend._get_items_to_delete(6000)
  732. surviving_items = set(items).difference(items_to_delete_6000b)
  733. assert max(ci.last_access for ci in items_to_delete_6000b) <= min(
  734. ci.last_access for ci in surviving_items
  735. )
  736. def test_memory_reduce_size_bytes_limit(tmpdir):
  737. memory, _, _ = _setup_toy_cache(tmpdir)
  738. ref_cache_items = memory.store_backend.get_items()
  739. # By default memory.bytes_limit is None and reduce_size is a noop
  740. memory.reduce_size()
  741. cache_items = memory.store_backend.get_items()
  742. assert sorted(ref_cache_items) == sorted(cache_items)
  743. # No cache items deleted if bytes_limit greater than the size of
  744. # the cache
  745. memory.reduce_size(bytes_limit="1M")
  746. cache_items = memory.store_backend.get_items()
  747. assert sorted(ref_cache_items) == sorted(cache_items)
  748. # bytes_limit is set so that only two cache items are kept
  749. memory.reduce_size(bytes_limit="3K")
  750. cache_items = memory.store_backend.get_items()
  751. assert set.issubset(set(cache_items), set(ref_cache_items))
  752. assert len(cache_items) == 2
  753. # bytes_limit set so that no cache item is kept
  754. bytes_limit_too_small = 500
  755. memory.reduce_size(bytes_limit=bytes_limit_too_small)
  756. cache_items = memory.store_backend.get_items()
  757. assert cache_items == []
  758. def test_memory_reduce_size_items_limit(tmpdir):
  759. memory, _, _ = _setup_toy_cache(tmpdir)
  760. ref_cache_items = memory.store_backend.get_items()
  761. # By default reduce_size is a noop
  762. memory.reduce_size()
  763. cache_items = memory.store_backend.get_items()
  764. assert sorted(ref_cache_items) == sorted(cache_items)
  765. # No cache items deleted if items_limit greater than the size of
  766. # the cache
  767. memory.reduce_size(items_limit=10)
  768. cache_items = memory.store_backend.get_items()
  769. assert sorted(ref_cache_items) == sorted(cache_items)
  770. # items_limit is set so that only two cache items are kept
  771. memory.reduce_size(items_limit=2)
  772. cache_items = memory.store_backend.get_items()
  773. assert set.issubset(set(cache_items), set(ref_cache_items))
  774. assert len(cache_items) == 2
  775. # item_limit set so that no cache item is kept
  776. memory.reduce_size(items_limit=0)
  777. cache_items = memory.store_backend.get_items()
  778. assert cache_items == []
  779. def test_memory_reduce_size_age_limit(tmpdir):
  780. import datetime
  781. import time
  782. memory, _, put_cache = _setup_toy_cache(tmpdir)
  783. ref_cache_items = memory.store_backend.get_items()
  784. # By default reduce_size is a noop
  785. memory.reduce_size()
  786. cache_items = memory.store_backend.get_items()
  787. assert sorted(ref_cache_items) == sorted(cache_items)
  788. # No cache items deleted if age_limit big.
  789. memory.reduce_size(age_limit=datetime.timedelta(days=1))
  790. cache_items = memory.store_backend.get_items()
  791. assert sorted(ref_cache_items) == sorted(cache_items)
  792. # age_limit is set so that only two cache items are kept
  793. time.sleep(1)
  794. put_cache(-1)
  795. put_cache(-2)
  796. memory.reduce_size(age_limit=datetime.timedelta(seconds=1))
  797. cache_items = memory.store_backend.get_items()
  798. assert not set.issubset(set(cache_items), set(ref_cache_items))
  799. assert len(cache_items) == 2
  800. # ensure age_limit is forced to be positive
  801. with pytest.raises(ValueError, match="has to be a positive"):
  802. memory.reduce_size(age_limit=datetime.timedelta(seconds=-1))
  803. # age_limit set so that no cache item is kept
  804. time.sleep(0.001) # make sure the age is different
  805. memory.reduce_size(age_limit=datetime.timedelta(seconds=0))
  806. cache_items = memory.store_backend.get_items()
  807. assert cache_items == []
  808. def test_memory_clear(tmpdir):
  809. memory, _, g = _setup_toy_cache(tmpdir)
  810. memory.clear()
  811. assert os.listdir(memory.store_backend.location) == []
  812. # Check that the cache for functions hash is also reset.
  813. assert not g._check_previous_func_code(stacklevel=4)
  814. def fast_func_with_complex_output():
  815. complex_obj = ["a" * 1000] * 1000
  816. return complex_obj
  817. def fast_func_with_conditional_complex_output(complex_output=True):
  818. complex_obj = {str(i): i for i in range(int(1e5))}
  819. return complex_obj if complex_output else "simple output"
  820. @with_multiprocessing
  821. def test_cached_function_race_condition_when_persisting_output(tmpdir, capfd):
  822. # Test race condition where multiple processes are writing into
  823. # the same output.pkl. See
  824. # https://github.com/joblib/joblib/issues/490 for more details.
  825. memory = Memory(location=tmpdir.strpath)
  826. func_cached = memory.cache(fast_func_with_complex_output)
  827. Parallel(n_jobs=2)(delayed(func_cached)() for i in range(3))
  828. stdout, stderr = capfd.readouterr()
  829. # Checking both stdout and stderr (ongoing PR #434 may change
  830. # logging destination) to make sure there is no exception while
  831. # loading the results
  832. exception_msg = "Exception while loading results"
  833. assert exception_msg not in stdout
  834. assert exception_msg not in stderr
  835. @with_multiprocessing
  836. def test_cached_function_race_condition_when_persisting_output_2(tmpdir, capfd):
  837. # Test race condition in first attempt at solving
  838. # https://github.com/joblib/joblib/issues/490. The race condition
  839. # was due to the delay between seeing the cache directory created
  840. # (interpreted as the result being cached) and the output.pkl being
  841. # pickled.
  842. memory = Memory(location=tmpdir.strpath)
  843. func_cached = memory.cache(fast_func_with_conditional_complex_output)
  844. Parallel(n_jobs=2)(
  845. delayed(func_cached)(True if i % 2 == 0 else False) for i in range(3)
  846. )
  847. stdout, stderr = capfd.readouterr()
  848. # Checking both stdout and stderr (ongoing PR #434 may change
  849. # logging destination) to make sure there is no exception while
  850. # loading the results
  851. exception_msg = "Exception while loading results"
  852. assert exception_msg not in stdout
  853. assert exception_msg not in stderr
  854. def test_memory_recomputes_after_an_error_while_loading_results(tmpdir, monkeypatch):
  855. memory = Memory(location=tmpdir.strpath)
  856. def func(arg):
  857. # This makes sure that the timestamp returned by two calls of
  858. # func are different. This is needed on Windows where
  859. # time.time resolution may not be accurate enough
  860. time.sleep(0.01)
  861. return arg, time.time()
  862. cached_func = memory.cache(func)
  863. input_arg = "arg"
  864. arg, timestamp = cached_func(input_arg)
  865. # Make sure the function is correctly cached
  866. assert arg == input_arg
  867. # Corrupting output.pkl to make sure that an error happens when
  868. # loading the cached result
  869. corrupt_single_cache_item(memory)
  870. # Make sure that corrupting the file causes recomputation and that
  871. # a warning is issued.
  872. recorded_warnings = monkeypatch_cached_func_warn(cached_func, monkeypatch)
  873. recomputed_arg, recomputed_timestamp = cached_func(arg)
  874. assert len(recorded_warnings) == 1
  875. exception_msg = "Exception while loading results"
  876. assert exception_msg in recorded_warnings[0]
  877. assert recomputed_arg == arg
  878. assert recomputed_timestamp > timestamp
  879. # Corrupting output.pkl to make sure that an error happens when
  880. # loading the cached result
  881. corrupt_single_cache_item(memory)
  882. reference = cached_func.call_and_shelve(arg)
  883. try:
  884. reference.get()
  885. raise AssertionError(
  886. "It normally not possible to load a corrupted MemorizedResult"
  887. )
  888. except KeyError as e:
  889. message = "is corrupted"
  890. assert message in str(e.args)
  891. class IncompleteStoreBackend(StoreBackendBase):
  892. """This backend cannot be instantiated and should raise a TypeError."""
  893. pass
  894. class DummyStoreBackend(StoreBackendBase):
  895. """A dummy store backend that does nothing."""
  896. def _open_item(self, *args, **kwargs):
  897. """Open an item on store."""
  898. "Does nothing"
  899. def _item_exists(self, location):
  900. """Check if an item location exists."""
  901. "Does nothing"
  902. def _move_item(self, src, dst):
  903. """Move an item from src to dst in store."""
  904. "Does nothing"
  905. def create_location(self, location):
  906. """Create location on store."""
  907. "Does nothing"
  908. def exists(self, obj):
  909. """Check if an object exists in the store"""
  910. return False
  911. def clear_location(self, obj):
  912. """Clear object on store"""
  913. "Does nothing"
  914. def get_items(self):
  915. """Returns the whole list of items available in cache."""
  916. return []
  917. def configure(self, location, *args, **kwargs):
  918. """Configure the store"""
  919. "Does nothing"
  920. @parametrize("invalid_prefix", [None, dict(), list()])
  921. def test_register_invalid_store_backends_key(invalid_prefix):
  922. # verify the right exceptions are raised when passing a wrong backend key.
  923. with raises(ValueError) as excinfo:
  924. register_store_backend(invalid_prefix, None)
  925. excinfo.match(r"Store backend name should be a string*")
  926. def test_register_invalid_store_backends_object():
  927. # verify the right exceptions are raised when passing a wrong backend
  928. # object.
  929. with raises(ValueError) as excinfo:
  930. register_store_backend("fs", None)
  931. excinfo.match(r"Store backend should inherit StoreBackendBase*")
  932. def test_memory_default_store_backend():
  933. # test an unknown backend falls back into a FileSystemStoreBackend
  934. with raises(TypeError) as excinfo:
  935. Memory(location="/tmp/joblib", backend="unknown")
  936. excinfo.match(r"Unknown location*")
  937. def test_warning_on_unknown_location_type():
  938. class NonSupportedLocationClass:
  939. pass
  940. unsupported_location = NonSupportedLocationClass()
  941. with warns(UserWarning) as warninfo:
  942. _store_backend_factory("local", location=unsupported_location)
  943. expected_mesage = (
  944. "Instantiating a backend using a "
  945. "NonSupportedLocationClass as a location is not "
  946. "supported by joblib"
  947. )
  948. assert expected_mesage in str(warninfo[0].message)
  949. def test_instanciate_incomplete_store_backend():
  950. # Verify that registering an external incomplete store backend raises an
  951. # exception when one tries to instantiate it.
  952. backend_name = "isb"
  953. register_store_backend(backend_name, IncompleteStoreBackend)
  954. assert (backend_name, IncompleteStoreBackend) in _STORE_BACKENDS.items()
  955. with raises(TypeError) as excinfo:
  956. _store_backend_factory(backend_name, "fake_location")
  957. excinfo.match(
  958. r"Can't instantiate abstract class IncompleteStoreBackend "
  959. "(without an implementation for|with) abstract methods*"
  960. )
  961. def test_dummy_store_backend():
  962. # Verify that registering an external store backend works.
  963. backend_name = "dsb"
  964. register_store_backend(backend_name, DummyStoreBackend)
  965. assert (backend_name, DummyStoreBackend) in _STORE_BACKENDS.items()
  966. backend_obj = _store_backend_factory(backend_name, "dummy_location")
  967. assert isinstance(backend_obj, DummyStoreBackend)
  968. def test_instanciate_store_backend_with_pathlib_path():
  969. # Instantiate a FileSystemStoreBackend using a pathlib.Path object
  970. path = Path("some_folder")
  971. backend_obj = _store_backend_factory("local", path)
  972. try:
  973. assert backend_obj.location == "some_folder"
  974. finally: # remove cache folder after test
  975. shutil.rmtree("some_folder", ignore_errors=True)
  976. def test_filesystem_store_backend_repr(tmpdir):
  977. # Verify string representation of a filesystem store backend.
  978. repr_pattern = 'FileSystemStoreBackend(location="{location}")'
  979. backend = FileSystemStoreBackend()
  980. assert backend.location is None
  981. repr(backend) # Should not raise an exception
  982. assert str(backend) == repr_pattern.format(location=None)
  983. # backend location is passed explicitly via the configure method (called
  984. # by the internal _store_backend_factory function)
  985. backend.configure(tmpdir.strpath)
  986. assert str(backend) == repr_pattern.format(location=tmpdir.strpath)
  987. repr(backend) # Should not raise an exception
  988. def test_memory_objects_repr(tmpdir):
  989. # Verify printable reprs of MemorizedResult, MemorizedFunc and Memory.
  990. def my_func(a, b):
  991. return a + b
  992. memory = Memory(location=tmpdir.strpath, verbose=0)
  993. memorized_func = memory.cache(my_func)
  994. memorized_func_repr = "MemorizedFunc(func={func}, location={location})"
  995. assert str(memorized_func) == memorized_func_repr.format(
  996. func=my_func, location=memory.store_backend.location
  997. )
  998. memorized_result = memorized_func.call_and_shelve(42, 42)
  999. memorized_result_repr = (
  1000. 'MemorizedResult(location="{location}", func="{func}", args_id="{args_id}")'
  1001. )
  1002. assert str(memorized_result) == memorized_result_repr.format(
  1003. location=memory.store_backend.location,
  1004. func=memorized_result.func_id,
  1005. args_id=memorized_result.args_id,
  1006. )
  1007. assert str(memory) == "Memory(location={location})".format(
  1008. location=memory.store_backend.location
  1009. )
  1010. def test_memorized_result_pickle(tmpdir):
  1011. # Verify a MemoryResult object can be pickled/depickled. Non regression
  1012. # test introduced following issue
  1013. # https://github.com/joblib/joblib/issues/747
  1014. memory = Memory(location=tmpdir.strpath)
  1015. @memory.cache
  1016. def g(x):
  1017. return x**2
  1018. memorized_result = g.call_and_shelve(4)
  1019. memorized_result_pickle = pickle.dumps(memorized_result)
  1020. memorized_result_loads = pickle.loads(memorized_result_pickle)
  1021. assert (
  1022. memorized_result.store_backend.location
  1023. == memorized_result_loads.store_backend.location
  1024. )
  1025. assert memorized_result.func == memorized_result_loads.func
  1026. assert memorized_result.args_id == memorized_result_loads.args_id
  1027. assert str(memorized_result) == str(memorized_result_loads)
  1028. def compare(left, right, ignored_attrs=None):
  1029. if ignored_attrs is None:
  1030. ignored_attrs = []
  1031. left_vars = vars(left)
  1032. right_vars = vars(right)
  1033. assert set(left_vars.keys()) == set(right_vars.keys())
  1034. for attr in left_vars.keys():
  1035. if attr in ignored_attrs:
  1036. continue
  1037. assert left_vars[attr] == right_vars[attr]
  1038. @pytest.mark.parametrize(
  1039. "memory_kwargs",
  1040. [
  1041. {"compress": 3, "verbose": 2},
  1042. {"mmap_mode": "r", "verbose": 5, "backend_options": {"parameter": "unused"}},
  1043. ],
  1044. )
  1045. def test_memory_pickle_dump_load(tmpdir, memory_kwargs):
  1046. memory = Memory(location=tmpdir.strpath, **memory_kwargs)
  1047. memory_reloaded = pickle.loads(pickle.dumps(memory))
  1048. # Compare Memory instance before and after pickle roundtrip
  1049. compare(memory.store_backend, memory_reloaded.store_backend)
  1050. compare(
  1051. memory,
  1052. memory_reloaded,
  1053. ignored_attrs=set(["store_backend", "timestamp", "_func_code_id"]),
  1054. )
  1055. assert hash(memory) == hash(memory_reloaded)
  1056. func_cached = memory.cache(f)
  1057. func_cached_reloaded = pickle.loads(pickle.dumps(func_cached))
  1058. # Compare MemorizedFunc instance before/after pickle roundtrip
  1059. compare(func_cached.store_backend, func_cached_reloaded.store_backend)
  1060. compare(
  1061. func_cached,
  1062. func_cached_reloaded,
  1063. ignored_attrs=set(["store_backend", "timestamp", "_func_code_id"]),
  1064. )
  1065. assert hash(func_cached) == hash(func_cached_reloaded)
  1066. # Compare MemorizedResult instance before/after pickle roundtrip
  1067. memorized_result = func_cached.call_and_shelve(1)
  1068. memorized_result_reloaded = pickle.loads(pickle.dumps(memorized_result))
  1069. compare(memorized_result.store_backend, memorized_result_reloaded.store_backend)
  1070. compare(
  1071. memorized_result,
  1072. memorized_result_reloaded,
  1073. ignored_attrs=set(["store_backend", "timestamp", "_func_code_id"]),
  1074. )
  1075. assert hash(memorized_result) == hash(memorized_result_reloaded)
  1076. def test_info_log(tmpdir, caplog):
  1077. caplog.set_level(logging.INFO)
  1078. x = 3
  1079. memory = Memory(location=tmpdir.strpath, verbose=20)
  1080. @memory.cache
  1081. def f(x):
  1082. return x**2
  1083. _ = f(x)
  1084. assert "Querying" in caplog.text
  1085. caplog.clear()
  1086. memory = Memory(location=tmpdir.strpath, verbose=0)
  1087. @memory.cache
  1088. def f(x):
  1089. return x**2
  1090. _ = f(x)
  1091. assert "Querying" not in caplog.text
  1092. caplog.clear()
  1093. class TestCacheValidationCallback:
  1094. "Tests on parameter `cache_validation_callback`"
  1095. def foo(self, x, d, delay=None):
  1096. d["run"] = True
  1097. if delay is not None:
  1098. time.sleep(delay)
  1099. return x * 2
  1100. def test_invalid_cache_validation_callback(self, memory):
  1101. "Test invalid values for `cache_validation_callback"
  1102. match = "cache_validation_callback needs to be callable. Got True."
  1103. with pytest.raises(ValueError, match=match):
  1104. memory.cache(cache_validation_callback=True)
  1105. @pytest.mark.parametrize("consider_cache_valid", [True, False])
  1106. def test_constant_cache_validation_callback(self, memory, consider_cache_valid):
  1107. "Test expiry of old results"
  1108. f = memory.cache(
  1109. self.foo,
  1110. cache_validation_callback=lambda _: consider_cache_valid,
  1111. ignore=["d"],
  1112. )
  1113. d1, d2 = {"run": False}, {"run": False}
  1114. assert f(2, d1) == 4
  1115. assert f(2, d2) == 4
  1116. assert d1["run"]
  1117. assert d2["run"] != consider_cache_valid
  1118. def test_memory_only_cache_long_run(self, memory):
  1119. "Test cache validity based on run duration."
  1120. def cache_validation_callback(metadata):
  1121. duration = metadata["duration"]
  1122. if duration > 0.1:
  1123. return True
  1124. f = memory.cache(
  1125. self.foo, cache_validation_callback=cache_validation_callback, ignore=["d"]
  1126. )
  1127. # Short run are not cached
  1128. d1, d2 = {"run": False}, {"run": False}
  1129. assert f(2, d1, delay=0) == 4
  1130. assert f(2, d2, delay=0) == 4
  1131. assert d1["run"]
  1132. assert d2["run"]
  1133. # Longer run are cached
  1134. d1, d2 = {"run": False}, {"run": False}
  1135. assert f(2, d1, delay=0.2) == 4
  1136. assert f(2, d2, delay=0.2) == 4
  1137. assert d1["run"]
  1138. assert not d2["run"]
  1139. def test_memory_expires_after(self, memory):
  1140. "Test expiry of old cached results"
  1141. f = memory.cache(
  1142. self.foo, cache_validation_callback=expires_after(seconds=0.3), ignore=["d"]
  1143. )
  1144. d1, d2, d3 = {"run": False}, {"run": False}, {"run": False}
  1145. assert f(2, d1) == 4
  1146. assert f(2, d2) == 4
  1147. time.sleep(0.5)
  1148. assert f(2, d3) == 4
  1149. assert d1["run"]
  1150. assert not d2["run"]
  1151. assert d3["run"]
  1152. class TestMemorizedFunc:
  1153. "Tests for the MemorizedFunc and NotMemorizedFunc classes"
  1154. @staticmethod
  1155. def f(x, counter):
  1156. counter[x] = counter.get(x, 0) + 1
  1157. return counter[x]
  1158. def test_call_method_memorized(self, memory):
  1159. "Test calling the function"
  1160. f = memory.cache(self.f, ignore=["counter"])
  1161. counter = {}
  1162. assert f(2, counter) == 1
  1163. assert f(2, counter) == 1
  1164. x, meta = f.call(2, counter)
  1165. assert x == 2, "f has not been called properly"
  1166. assert isinstance(meta, dict), (
  1167. "Metadata are not returned by MemorizedFunc.call."
  1168. )
  1169. def test_call_method_not_memorized(self, memory):
  1170. "Test calling the function"
  1171. f = NotMemorizedFunc(self.f)
  1172. counter = {}
  1173. assert f(2, counter) == 1
  1174. assert f(2, counter) == 2
  1175. x, meta = f.call(2, counter)
  1176. assert x == 3, "f has not been called properly"
  1177. assert isinstance(meta, dict), (
  1178. "Metadata are not returned by MemorizedFunc.call."
  1179. )
  1180. class TestAutoGitignore:
  1181. "Tests for the MemorizedFunc and NotMemorizedFunc classes"
  1182. def test_memory_creates_gitignore(self, tmpdir):
  1183. """Test that using the memory object automatically creates a `.gitignore` file
  1184. within the new cache directory."""
  1185. location = Path(tmpdir.mkdir("test_cache_dir"))
  1186. mem = Memory(location)
  1187. costly_operation = mem.cache(id)
  1188. costly_operation(0)
  1189. gitignore_file = location / ".gitignore"
  1190. assert gitignore_file.exists()
  1191. assert gitignore_file.read_text() == "# Created by joblib automatically.\n*\n"
  1192. def test_memory_does_not_overwrite_existing_gitignore(self, tmpdir):
  1193. """Test that using the memory object does not overwrite an existing
  1194. `.gitignore` file within the cache directory."""
  1195. location = Path(tmpdir.mkdir("test_cache_dir"))
  1196. gitignore_file = location / ".gitignore"
  1197. existing_content = "# Existing .gitignore file!"
  1198. gitignore_file.write_text(existing_content)
  1199. # Cache a function and call it.
  1200. mem = Memory(location)
  1201. mem.cache(id)(0)
  1202. assert gitignore_file.exists()
  1203. assert gitignore_file.read_text() == existing_content