| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225 |
- """Test the numpy pickler as a replacement of the standard pickler."""
- import bz2
- import copy
- import gzip
- import io
- import mmap
- import os
- import pickle
- import random
- import re
- import socket
- import sys
- import warnings
- import zlib
- from contextlib import closing
- from pathlib import Path
- try:
- import lzma
- except ImportError:
- lzma = None
- import pytest
- # numpy_pickle is not a drop-in replacement of pickle, as it takes
- # filenames instead of open files as arguments.
- from joblib import numpy_pickle, register_compressor
- from joblib.compressor import (
- _COMPRESSORS,
- _LZ4_PREFIX,
- LZ4_NOT_INSTALLED_ERROR,
- BinaryZlibFile,
- CompressorWrapper,
- )
- from joblib.numpy_pickle_utils import (
- _IO_BUFFER_SIZE,
- _detect_compressor,
- _ensure_native_byte_order,
- _is_numpy_array_byte_order_mismatch,
- )
- from joblib.test import data
- from joblib.test.common import (
- memory_used,
- np,
- with_lz4,
- with_memory_profiler,
- with_numpy,
- without_lz4,
- )
- from joblib.testing import parametrize, raises, warns
- ###############################################################################
- # Define a list of standard types.
- # Borrowed from dill, initial author: Michael McKerns:
- # http://dev.danse.us/trac/pathos/browser/dill/dill_test2.py
# Sample values covering builtin and user-defined types. The order of
# ``typelist`` is significant: it drives the parametrization (and test ids)
# of ``test_standard_types``.
_none = None
_type = type
_bool = bool(1)
_int = int(1)
_float = float(1)
_complex = complex(1)
_string = str(1)
_tuple = ()
_list = []
_dict = {}
_builtin = len


def _function(x):
    """Sample generator function."""
    yield x


class _class:
    def _method(self):
        pass


class _newclass(object):
    def _method(self):
        pass


_instance = _class()
_object = _newclass()

typelist = [
    _none,
    _type,
    _bool,
    _int,
    _float,
    _complex,
    _string,
    _tuple,
    _list,
    _dict,
    _builtin,
    _function,
    _class,
    _newclass,
    _instance,
    _object,
]
- ###############################################################################
- # Tests
@parametrize("compress", [0, 1])
@parametrize("member", typelist)
def test_standard_types(tmpdir, compress, member):
    """Round-trip each standard-type sample through dump/load."""
    pickle_path = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump(member, pickle_path, compress=compress)
    reloaded = numpy_pickle.load(pickle_path)

    # Only members equal to their own deep copy can meaningfully be
    # compared with the reloaded value.
    comparable = member == copy.deepcopy(member)
    if comparable:
        assert member == reloaded
def test_value_error():
    """Swapping the object and target arguments of ``dump`` raises."""
    # A dict is neither a filename nor a file object.
    bogus_target = dict()
    with raises(ValueError):
        numpy_pickle.dump("foo", bogus_target)
@parametrize("wrong_compress", [-1, 10, dict()])
def test_compress_level_error(wrong_compress):
    """An invalid ``compress`` argument raises ValueError.

    The expected message embeds the string form of the bad value (``{}`` for
    the dict case). It is escaped before being handed to ``excinfo.match``,
    which interprets its argument as a regular expression.
    """
    exception_msg = 'Non valid compress level given: "{0}"'.format(wrong_compress)
    with raises(ValueError) as excinfo:
        numpy_pickle.dump("dummy", "foo", compress=wrong_compress)
    excinfo.match(re.escape(exception_msg))
@with_numpy
@parametrize("compress", [False, True, 0, 3, "zlib"])
def test_numpy_persistence(tmpdir, compress):
    """Round-trip plain arrays, a memmap and an object holding arrays.

    Every dump must produce exactly one file, whatever ``compress`` is.
    """
    filename = tmpdir.join("test.pkl").strpath
    rnd = np.random.RandomState(0)
    a = rnd.random_sample((10, 2))
    # We use 'a.T' to have a non C-contiguous array.
    for obj in ((a,), (a.T,), (a, a), [a, a, a]):
        filenames = numpy_pickle.dump(obj, filename, compress=compress)

        # All is cached in one file: the requested one, which must exist.
        assert len(filenames) == 1
        assert filenames[0] == filename
        assert os.path.exists(filenames[0])

        # Unpickle the object and check that the items are indeed arrays.
        obj_ = numpy_pickle.load(filename)
        for item in obj_:
            assert isinstance(item, np.ndarray)
        # And finally, check that all the values are equal.
        np.testing.assert_array_equal(np.array(obj), np.array(obj_))

    # Now test with an array subclass
    obj = np.memmap(filename + "mmap", mode="w+", shape=4, dtype=np.float64)
    filenames = numpy_pickle.dump(obj, filename, compress=compress)
    # All is cached in one file
    assert len(filenames) == 1

    obj_ = numpy_pickle.load(filename)
    # NOTE(review): ``obj`` is always an ``np.memmap`` here, so this guard is
    # never true; it only records that memmaps are not reconstructed.
    if type(obj) is not np.memmap and hasattr(obj, "__array_prepare__"):
        assert isinstance(obj_, type(obj))
    np.testing.assert_array_equal(obj_, obj)

    # Test with an object containing multiple numpy arrays
    obj = ComplexTestObject()
    filenames = numpy_pickle.dump(obj, filename, compress=compress)
    # All is cached in one file
    assert len(filenames) == 1

    obj_loaded = numpy_pickle.load(filename)
    assert isinstance(obj_loaded, type(obj))
    np.testing.assert_array_equal(obj_loaded.array_float, obj.array_float)
    np.testing.assert_array_equal(obj_loaded.array_int, obj.array_int)
    np.testing.assert_array_equal(obj_loaded.array_obj, obj.array_obj)
@with_numpy
def test_numpy_persistence_bufferred_array_compression(tmpdir):
    """A compressed array slightly larger than the IO buffer round trips."""
    arr = np.ones((_IO_BUFFER_SIZE + 100), dtype=np.uint8)
    target = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump(arr, target, compress=True)
    np.testing.assert_array_equal(arr, numpy_pickle.load(target))
@with_numpy
def test_memmap_persistence(tmpdir):
    """Loading with ``mmap_mode`` returns memmaps that honour the mode."""
    rnd = np.random.RandomState(0)
    a = rnd.random_sample(10)
    filename = tmpdir.join("test1.pkl").strpath
    numpy_pickle.dump(a, filename)
    b = numpy_pickle.load(filename, mmap_mode="r")

    assert isinstance(b, np.memmap)

    # Test with an object containing multiple numpy arrays
    filename = tmpdir.join("test2.pkl").strpath
    obj = ComplexTestObject()
    numpy_pickle.dump(obj, filename)
    obj_loaded = numpy_pickle.load(filename, mmap_mode="r")
    assert isinstance(obj_loaded, type(obj))
    assert isinstance(obj_loaded.array_float, np.memmap)
    assert not obj_loaded.array_float.flags.writeable
    assert isinstance(obj_loaded.array_int, np.memmap)
    assert not obj_loaded.array_int.flags.writeable
    # Memory map not allowed for numpy object arrays
    assert not isinstance(obj_loaded.array_obj, np.memmap)
    np.testing.assert_array_equal(obj_loaded.array_float, obj.array_float)
    np.testing.assert_array_equal(obj_loaded.array_int, obj.array_int)
    np.testing.assert_array_equal(obj_loaded.array_obj, obj.array_obj)

    # Test we can write in memmapped arrays
    obj_loaded = numpy_pickle.load(filename, mmap_mode="r+")
    assert obj_loaded.array_float.flags.writeable
    obj_loaded.array_float[0:10] = 10.0
    assert obj_loaded.array_int.flags.writeable
    obj_loaded.array_int[0:10] = 10

    obj_reloaded = numpy_pickle.load(filename, mmap_mode="r")
    np.testing.assert_array_equal(obj_reloaded.array_float, obj_loaded.array_float)
    np.testing.assert_array_equal(obj_reloaded.array_int, obj_loaded.array_int)

    # Test w+ mode is caught and the mode has switched to r+. The result of
    # the w+ load must be kept: asserting on the previous r+ object would
    # pass no matter how w+ is handled.
    obj_loaded = numpy_pickle.load(filename, mmap_mode="w+")
    assert obj_loaded.array_int.flags.writeable
    assert obj_loaded.array_int.mode == "r+"
    assert obj_loaded.array_float.flags.writeable
    assert obj_loaded.array_float.mode == "r+"
@with_numpy
def test_memmap_persistence_mixed_dtypes(tmpdir):
    """Object-dtype members must not prevent memmapping fixed-size members."""
    random_state = np.random.RandomState(0)
    floats = random_state.random_sample(10)
    objects = np.array([1, "b"], dtype=object)
    path = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump((floats, objects), path)

    floats_clone, objects_clone = numpy_pickle.load(path, mmap_mode="r")
    # The floating point array is memory mapped...
    assert isinstance(floats_clone, np.memmap)
    # ...while the object-dtype array is loaded in memory.
    assert not isinstance(objects_clone, np.memmap)
@with_numpy
def test_masked_array_persistence(tmpdir):
    """Masked arrays can be dumped and reloaded as masked arrays."""
    # The special-case pickler does not implement masked arrays; saving them
    # simply delegates to the standard pickler.
    rng = np.random.RandomState(0)
    masked = np.ma.masked_greater(rng.random_sample(10), 0.5)
    path = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump(masked, path)
    reloaded = numpy_pickle.load(path, mmap_mode="r")
    assert isinstance(reloaded, np.ma.masked_array)
@with_numpy
def test_compress_mmap_mode_warning(tmpdir):
    """``mmap_mode`` on a compressed file warns once and loads in memory."""
    rng = np.random.RandomState(0)
    obj = rng.random_sample(10)
    this_filename = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump(obj, this_filename, compress=1)

    with warns(UserWarning) as recorded:
        reloaded_obj = numpy_pickle.load(this_filename, mmap_mode="r+")
    debug_msg = "\n".join(str(w) for w in recorded)
    messages = [w.message for w in recorded]

    assert not isinstance(reloaded_obj, np.memmap)
    np.testing.assert_array_equal(obj, reloaded_obj)
    assert len(messages) == 1, debug_msg
    expected = (
        'mmap_mode "r+" is not compatible with compressed '
        f'file {this_filename}. "r+" flag will be ignored.'
    )
    assert str(messages[0]) == expected
@with_numpy
@with_memory_profiler
@parametrize("compress", [True, False])
def test_memory_usage(tmpdir, compress):
    """Peak memory of dump/load stays within the chunked-IO budgets.

    The bounds are expressed in megabytes, like ``size`` below.
    ``_IO_BUFFER_SIZE`` is a byte count, since it sizes a uint8 array
    elsewhere in this module. It is therefore converted to megabytes before
    being added; added raw, it would push both bounds to about a terabyte
    and the assertions could never fail.
    """
    filename = tmpdir.join("test.pkl").strpath
    small_array = np.ones((10, 10))
    big_array = np.ones(shape=100 * int(1e6), dtype=np.uint8)
    io_buffer_mb = _IO_BUFFER_SIZE / 1e6

    for obj in (small_array, big_array):
        size = obj.nbytes / 1e6
        obj_filename = filename + str(np.random.randint(0, 1000))
        mem_used = memory_used(numpy_pickle.dump, obj, obj_filename, compress=compress)

        # The memory used to dump the object shouldn't exceed the buffer
        # size used to write array chunks (16MB).
        write_buf_size = io_buffer_mb + 16 * 1024**2 / 1e6
        assert mem_used <= write_buf_size

        mem_used = memory_used(numpy_pickle.load, obj_filename)
        # memory used should be less than array size + buffer size used to
        # read the array chunk by chunk.
        read_buf_size = 32 + io_buffer_mb  # MB
        assert mem_used < size + read_buf_size
@with_numpy
def test_compressed_pickle_dump_and_load(tmpdir):
    """A compressed list of arrays, bytes and text round trips exactly."""
    expected_list = [
        np.arange(5, dtype=np.dtype("<i8")),
        np.arange(5, dtype=np.dtype(">i8")),
        np.arange(5, dtype=np.dtype("<f8")),
        np.arange(5, dtype=np.dtype(">f8")),
        np.array([1, "abc", {"a": 1, "b": 2}], dtype="O"),
        np.arange(256, dtype=np.uint8).tobytes(),
        "C'est l'\xe9t\xe9 !",
    ]

    fname = tmpdir.join("temp.pkl.gz").strpath
    dumped_filenames = numpy_pickle.dump(expected_list, fname, compress=1)
    assert len(dumped_filenames) == 1
    result_list = numpy_pickle.load(fname)

    for result, expected in zip(result_list, expected_list):
        if not isinstance(expected, np.ndarray):
            assert result == expected
            continue
        # Reloaded arrays are compared with the native-byte-order version.
        native = _ensure_native_byte_order(expected)
        assert result.dtype == native.dtype
        np.testing.assert_equal(result, native)
@with_numpy
def test_memmap_load(tmpdir):
    """``ensure_native_byte_order`` controls the dtype of reloaded arrays."""
    le_dtype = np.dtype("<i8")
    be_dtype = np.dtype(">i8")
    le_array = np.arange(5, dtype=le_dtype)
    be_array = np.arange(5, dtype=be_dtype)

    fname = tmpdir.join("temp.pkl").strpath
    numpy_pickle.dump([le_array, be_array], fname)

    # With conversion, both arrays end up sharing one dtype.
    le_native, be_native = numpy_pickle.load(fname, ensure_native_byte_order=True)
    assert le_native.dtype == be_native.dtype
    assert le_native.dtype in (le_dtype, be_dtype)

    # Without conversion, each array keeps its original dtype.
    le_raw, be_raw = numpy_pickle.load(fname, ensure_native_byte_order=False)
    assert le_raw.dtype == le_array.dtype
    assert be_raw.dtype == be_array.dtype
def test_invalid_parameters_raise():
    """Native byte order cannot be enforced together with memory mapping."""
    expected_msg = (
        "Native byte ordering can only be enforced if 'mmap_mode' parameter "
        "is set to None, but got 'mmap_mode=r+' instead."
    )
    load_kwargs = dict(ensure_native_byte_order=True, mmap_mode="r+")
    with raises(ValueError, match=re.escape(expected_msg)):
        numpy_pickle.load("/path/to/some/dump.pkl", **load_kwargs)
def _check_pickle(filename, expected_list, mmap_mode=None):
    """Helper function to test joblib pickle content.

    Note: currently only pickles containing an iterable are supported
    by this function.

    Loads ``filename`` with ``numpy_pickle.load``, checks the warnings
    emitted while loading, and compares the loaded items with
    ``expected_list``. The Python major version used for writing is parsed
    from the ``py<major><minor>`` fragment of the filename. Python 2 pickles
    and lz4 pickles (when lz4 is not installed) are expected to fail with a
    ``ValueError`` carrying a specific message.

    Parameters
    ----------
    filename : str
        Path to a reference pickle; its name must contain ``py`` followed by
        two digits.
    expected_list : list
        Expected items, compared in order with the loaded iterable.
    mmap_mode : str or None, optional
        Forwarded to ``numpy_pickle.load``.
    """
    version_match = re.match(r".+py(\d)(\d).+", filename)
    py_version_used_for_writing = int(version_match.group(1))

    py_version_to_default_pickle_protocol = {2: 2, 3: 3}
    # Key 3 is present, so the reading protocol always evaluates to 3.
    pickle_reading_protocol = py_version_to_default_pickle_protocol.get(3, 4)
    pickle_writing_protocol = py_version_to_default_pickle_protocol.get(
        py_version_used_for_writing, 4
    )
    if pickle_reading_protocol >= pickle_writing_protocol:
        try:
            with warnings.catch_warnings(record=True) as warninfo:
                warnings.simplefilter("always")
                result_list = numpy_pickle.load(filename, mmap_mode=mmap_mode)
            filename_base = os.path.basename(filename)
            # Pickles generated by joblib < 0.10 emit one deprecation warning.
            expected_nb_deprecation_warnings = (
                1 if ("_0.9" in filename_base or "_0.8.4" in filename_base) else 0
            )

            # Memory mapping a joblib 0.1x ``.pkl`` file emits three user
            # warnings. The dots are escaped so that they match literally.
            expected_nb_user_warnings = (
                3
                if (
                    re.search(r"_0\.1.+\.pkl$", filename_base)
                    and mmap_mode is not None
                )
                else 0
            )
            expected_nb_warnings = (
                expected_nb_deprecation_warnings + expected_nb_user_warnings
            )
            assert len(warninfo) == expected_nb_warnings, (
                "Did not get the expected number of warnings. Expected "
                f"{expected_nb_warnings} but got warnings: "
                f"{[w.message for w in warninfo]}"
            )

            deprecation_warnings = [
                w for w in warninfo if issubclass(w.category, DeprecationWarning)
            ]
            user_warnings = [w for w in warninfo if issubclass(w.category, UserWarning)]
            for w in deprecation_warnings:
                assert (
                    str(w.message)
                    == "The file '{0}' has been generated with a joblib "
                    "version less than 0.10. Please regenerate this "
                    "pickle file.".format(filename)
                )

            for w in user_warnings:
                escaped_filename = re.escape(filename)
                assert re.search(
                    f"memmapped.+{escaped_filename}.+segmentation fault", str(w.message)
                )

            for result, expected in zip(result_list, expected_list):
                if isinstance(expected, np.ndarray):
                    expected = _ensure_native_byte_order(expected)
                    assert result.dtype == expected.dtype
                    np.testing.assert_equal(result, expected)
                else:
                    assert result == expected
        except Exception as exc:
            # When trying to read with python 3 a pickle generated
            # with python 2 we expect a user-friendly error
            if py_version_used_for_writing == 2:
                assert isinstance(exc, ValueError)
                message = (
                    "You may be trying to read with "
                    "python 3 a joblib pickle generated with python 2."
                )
                assert message in str(exc)
            elif filename.endswith(".lz4") and with_lz4.args[0]:
                # lz4 is not installed: loading must fail with a clear error.
                assert isinstance(exc, ValueError)
                assert LZ4_NOT_INSTALLED_ERROR in str(exc)
            else:
                raise
    else:
        # Pickle protocol used for writing is too high. We expect a
        # "unsupported pickle protocol" error message
        try:
            numpy_pickle.load(filename)
            raise AssertionError(
                "Numpy pickle loading should have raised a ValueError exception"
            )
        except ValueError as e:
            message = "unsupported pickle protocol: {0}".format(pickle_writing_protocol)
            assert message in str(e.args)
@with_numpy
def test_joblib_pickle_across_python_versions():
    """Every reference pickle in joblib/test/data loads to the expected content."""
    # We need to be specific about dtypes in particular endianness
    # because the pickles can be generated on one architecture and
    # the tests run on another one. See
    # https://github.com/joblib/joblib/issues/279.
    expected_list = [
        np.arange(5, dtype=np.dtype("<i8")),
        np.arange(5, dtype=np.dtype("<f8")),
        np.array([1, "abc", {"a": 1, "b": 2}], dtype="O"),
        np.arange(256, dtype=np.uint8).tobytes(),
        # np.matrix is a subclass of np.ndarray, here we want
        # to verify this type of object is correctly unpickled
        # among versions.
        np.matrix([0, 1, 2], dtype=np.dtype("<i8")),
        "C'est l'\xe9t\xe9 !",
    ]

    # Testing all the compressed and non compressed
    # pickles in joblib/test/data. These pickles were generated by
    # the joblib/test/data/create_numpy_pickle.py script for the
    # relevant python, joblib and numpy versions.
    test_data_dir = os.path.dirname(os.path.abspath(data.__file__))

    # Every extension carries its leading dot so that only genuine file
    # suffixes match.
    pickle_extensions = (".pkl", ".gz", ".gzip", ".bz2", ".lz4")
    if lzma is not None:
        pickle_extensions += (".xz", ".lzma")

    pickle_filenames = [
        os.path.join(test_data_dir, fn)
        for fn in os.listdir(test_data_dir)
        if fn.endswith(pickle_extensions)
    ]

    for fname in pickle_filenames:
        _check_pickle(fname, expected_list)
@with_numpy
def test_joblib_pickle_across_python_versions_with_mmap():
    """Reference ``.pkl`` files also load correctly with ``mmap_mode="r"``."""
    expected_list = [
        np.arange(5, dtype=np.dtype("<i8")),
        np.arange(5, dtype=np.dtype("<f8")),
        np.array([1, "abc", {"a": 1, "b": 2}], dtype="O"),
        np.arange(256, dtype=np.uint8).tobytes(),
        # np.matrix is an ndarray subclass; its unpickling must also work
        # across versions.
        np.matrix([0, 1, 2], dtype=np.dtype("<i8")),
        "C'est l'\xe9t\xe9 !",
    ]
    data_dir = os.path.dirname(os.path.abspath(data.__file__))

    # Only the uncompressed ``.pkl`` reference files are exercised here.
    for entry in os.listdir(data_dir):
        if not entry.endswith(".pkl"):
            continue
        _check_pickle(os.path.join(data_dir, entry), expected_list, mmap_mode="r")
def _assert_native_after_conversion(array):
    """Assert that ``_ensure_native_byte_order(array)`` has a native dtype.

    For structured dtypes every field dtype is checked individually.
    """
    converted = _ensure_native_byte_order(array)
    if converted.dtype.fields:
        for field_dtype, *_ in converted.dtype.fields.values():
            assert field_dtype.byteorder == "="
    else:
        assert converted.dtype.byteorder == "="


@with_numpy
def test_numpy_array_byte_order_mismatch_detection():
    """Byte-order mismatches are detected and fixed for both endiannesses."""
    # List of numpy arrays with big endian byteorder.
    be_arrays = [
        np.array([(1, 2.0), (3, 4.0)], dtype=[("", ">i8"), ("", ">f8")]),
        np.arange(3, dtype=np.dtype(">i8")),
        np.arange(3, dtype=np.dtype(">f8")),
    ]

    # Verify the byteorder mismatch is correctly detected.
    for array in be_arrays:
        if sys.byteorder == "big":
            assert not _is_numpy_array_byte_order_mismatch(array)
        else:
            assert _is_numpy_array_byte_order_mismatch(array)
        _assert_native_after_conversion(array)

    # List of numpy arrays with little endian byteorder.
    le_arrays = [
        np.array([(1, 2.0), (3, 4.0)], dtype=[("", "<i8"), ("", "<f8")]),
        np.arange(3, dtype=np.dtype("<i8")),
        np.arange(3, dtype=np.dtype("<f8")),
    ]

    # Verify the byteorder mismatch is correctly detected.
    for array in le_arrays:
        if sys.byteorder == "little":
            assert not _is_numpy_array_byte_order_mismatch(array)
        else:
            assert _is_numpy_array_byte_order_mismatch(array)
        _assert_native_after_conversion(array)
@parametrize("compress_tuple", [("zlib", 3), ("gzip", 3)])
def test_compress_tuple_argument(tmpdir, compress_tuple):
    """A ``(method, level)`` tuple selects the compression method."""
    method = compress_tuple[0]
    target = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump("dummy", target, compress=compress_tuple)
    # The file header must carry the magic number of the chosen method.
    with open(target, "rb") as fh:
        assert _detect_compressor(fh) == method
@parametrize(
    "compress_tuple,message",
    [
        (
            ("zlib", 3, "extra"),  # wrong compress tuple
            "Compress argument tuple should contain exactly 2 elements",
        ),
        (
            ("wrong", 3),  # wrong compress method
            'Non valid compression method given: "{}"'.format("wrong"),
        ),
        (
            ("zlib", "wrong"),  # wrong compress level
            'Non valid compress level given: "{}"'.format("wrong"),
        ),
    ],
)
def test_compress_tuple_argument_exception(tmpdir, compress_tuple, message):
    """Malformed ``compress`` tuples raise ValueError with a precise message."""
    filename = tmpdir.join("test.pkl").strpath
    # Verify setting a wrong compress tuple raises a ValueError.
    with raises(ValueError) as excinfo:
        numpy_pickle.dump("dummy", filename, compress=compress_tuple)
    # ``message`` is literal text, while ``match`` expects a regex.
    excinfo.match(re.escape(message))
@parametrize("compress_string", ["zlib", "gzip"])
def test_compress_string_argument(tmpdir, compress_string):
    """A bare method name passed as ``compress`` selects that compressor."""
    target = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump("dummy", target, compress=compress_string)
    # The detected magic number must match the requested method.
    with open(target, "rb") as fh:
        detected = _detect_compressor(fh)
    assert detected == compress_string
@with_numpy
@parametrize("compress", [1, 3, 6])
@parametrize("cmethod", _COMPRESSORS)
def test_joblib_compression_formats(tmpdir, compress, cmethod):
    """Each registered compressor writes its magic number and round trips."""
    if cmethod in ("lzma", "xz") and lzma is None:
        pytest.skip("lzma support is not available")
    elif cmethod == "lz4" and with_lz4.args[0]:
        # Skip the test if lz4 is not installed. We here use the with_lz4
        # skipif fixture whose argument is True when lz4 is not installed
        pytest.skip("lz4 is not installed.")

    filename = tmpdir.join("test.pkl").strpath
    objects = (
        np.ones(shape=(100, 100), dtype="f8"),
        range(10),
        {"a": 1, 2: "b"},
        [],
        (),
        {},
        0,
        1.0,
    )

    dump_filename = filename + "." + cmethod
    for obj in objects:
        numpy_pickle.dump(obj, dump_filename, compress=(cmethod, compress))
        # Verify the file contains the right magic number
        with open(dump_filename, "rb") as f:
            assert _detect_compressor(f) == cmethod
        # Verify the reloaded object is correct
        obj_reloaded = numpy_pickle.load(dump_filename)
        assert isinstance(obj_reloaded, type(obj))
        if isinstance(obj, np.ndarray):
            np.testing.assert_array_equal(obj_reloaded, obj)
        else:
            assert obj_reloaded == obj
- def _gzip_file_decompress(source_filename, target_filename):
- """Decompress a gzip file."""
- with closing(gzip.GzipFile(source_filename, "rb")) as fo:
- buf = fo.read()
- with open(target_filename, "wb") as fo:
- fo.write(buf)
- def _zlib_file_decompress(source_filename, target_filename):
- """Decompress a zlib file."""
- with open(source_filename, "rb") as fo:
- buf = zlib.decompress(fo.read())
- with open(target_filename, "wb") as fo:
- fo.write(buf)
@parametrize(
    "extension,decompress",
    [(".z", _zlib_file_decompress), (".gz", _gzip_file_decompress)],
)
def test_load_externally_decompressed_files(tmpdir, extension, decompress):
    """BinaryZlibFile output is a valid zlib/gzip stream.

    The compressed dump is inflated with the standard library, and the raw
    pickle is then loaded back with ``numpy_pickle.load``.
    """
    obj = "a string to persist"
    raw_path = tmpdir.join("test.pkl").strpath
    compressed_path = raw_path + extension

    # The filename extension alone selects the compression method.
    numpy_pickle.dump(obj, compressed_path)
    decompress(compressed_path, raw_path)

    assert obj == numpy_pickle.load(raw_path)
@parametrize(
    "extension,cmethod",
    # valid compressor extensions
    [
        (".z", "zlib"),
        (".gz", "gzip"),
        (".bz2", "bz2"),
        (".lzma", "lzma"),
        (".xz", "xz"),
        # invalid compressor extensions
        (".pkl", "not-compressed"),
        ("", "not-compressed"),
    ],
)
def test_compression_using_file_extension(tmpdir, extension, cmethod):
    """The filename extension picks the compressor; unknown ones are raw."""
    if cmethod in ("lzma", "xz") and lzma is None:
        pytest.skip("lzma is missing")
    obj = "object to dump"
    dump_fname = tmpdir.join("test.pkl").strpath + extension
    numpy_pickle.dump(obj, dump_fname)

    # The file header must carry the expected magic number.
    with open(dump_fname, "rb") as fh:
        assert _detect_compressor(fh) == cmethod

    # The reloaded object must match the original one.
    reloaded = numpy_pickle.load(dump_fname)
    assert isinstance(reloaded, type(obj))
    assert reloaded == obj
@with_numpy
def test_file_handle_persistence(tmpdir):
    """Dump into compressed file objects and reload through two routes.

    Each object is reloaded once through the same compressed file class and
    once through a plain binary handle, where the decompressor has to be
    detected.
    """
    objs = [np.random.random((10, 10)), "some data"]
    file_classes = [bz2.BZ2File, gzip.GzipFile]
    if lzma is not None:
        file_classes += [lzma.LZMAFile]
    path = tmpdir.join("test.pkl").strpath

    def check(reloaded, expected):
        if isinstance(expected, np.ndarray):
            np.testing.assert_array_equal(reloaded, expected)
        else:
            assert reloaded == expected

    for obj in objs:
        for file_class in file_classes:
            with file_class(path, "wb") as out:
                numpy_pickle.dump(obj, out)

            # Reading through the same class: no second decompression.
            with file_class(path, "rb") as src:
                via_same_class = numpy_pickle.load(src)

            # Reading a raw handle: the right decompressor must be picked.
            with open(path, "rb") as src:
                via_raw_handle = numpy_pickle.load(src)

            check(via_same_class, obj)
            check(via_raw_handle, obj)
@with_numpy
def test_in_memory_persistence():
    """Objects can be dumped to and reloaded from an in-memory buffer."""
    for obj in [np.random.random((10, 10)), "some data"]:
        buffer = io.BytesIO()
        numpy_pickle.dump(obj, buffer)
        # NOTE(review): the buffer is not rewound before loading; this seems
        # to rely on numpy_pickle.load repositioning the stream. Confirm.
        reloaded = numpy_pickle.load(buffer)
        if not isinstance(obj, np.ndarray):
            assert reloaded == obj
            continue
        np.testing.assert_array_equal(reloaded, obj)
@with_numpy
def test_file_handle_persistence_mmap(tmpdir):
    """Loading from a raw file handle with ``mmap_mode`` returns the data."""
    arr = np.random.random((10, 10))
    path = tmpdir.join("test.pkl").strpath
    with open(path, "wb") as out:
        numpy_pickle.dump(arr, out)
    with open(path, "rb") as src:
        reloaded = numpy_pickle.load(src, mmap_mode="r+")
    np.testing.assert_array_equal(reloaded, arr)
@with_numpy
def test_file_handle_persistence_compressed_mmap(tmpdir):
    """``mmap_mode`` is ignored, with one warning, for a gzip file object."""
    arr = np.random.random((10, 10))
    path = tmpdir.join("test.pkl").strpath

    with open(path, "wb") as out:
        numpy_pickle.dump(arr, out, compress=("gzip", 3))

    with closing(gzip.GzipFile(path, "rb")) as gz_file:
        with warns(UserWarning) as recorded:
            numpy_pickle.load(gz_file, mmap_mode="r+")
        # The message embeds repr(gz_file), so it is checked while the file
        # is still open.
        expected = (
            '"%(fileobj)r" is not a raw file, mmap_mode "%(mmap_mode)s" '
            "flag will be ignored." % {"fileobj": gz_file, "mmap_mode": "r+"}
        )
        assert len(recorded) == 1
        assert str(recorded[0].message) == expected
@with_numpy
def test_file_handle_persistence_in_memory_mmap():
    """``mmap_mode`` is ignored, with one warning, for in-memory buffers."""
    arr = np.random.random((10, 10))
    buffer = io.BytesIO()
    numpy_pickle.dump(arr, buffer)

    with warns(UserWarning) as recorded:
        numpy_pickle.load(buffer, mmap_mode="r+")

    expected = (
        "In memory persistence is not compatible with mmap_mode "
        '"%(mmap_mode)s" flag passed. mmap_mode option will be '
        "ignored." % {"mmap_mode": "r+"}
    )
    assert len(recorded) == 1
    assert str(recorded[0].message) == expected
@parametrize(
    "data",
    [
        b"a little data as bytes.",
        # More bytes. NOTE: drawn once, when the decorator runs at import
        # time, from the unseeded ``random`` module.
        10000 * "{}".format(random.randint(0, 1000) * 1000).encode("latin-1"),
    ],
    ids=["a little data as bytes.", "a large data as bytes."],
)
@parametrize("compress_level", [1, 3, 9])
def test_binary_zlibfile(tmpdir, data, compress_level):
    """Exercise BinaryZlibFile on file objects and on filenames.

    Covers use with and without a context manager, the read/write/seek mode
    guards, and the closed state once the wrapper has been closed.
    """
    filename = tmpdir.join("test.pkl").strpath
    # Regular cases: wrap an already-open binary file object for writing.
    with open(filename, "wb") as f:
        with BinaryZlibFile(f, "wb", compresslevel=compress_level) as fz:
            assert fz.writable()
            fz.write(data)
            # The wrapper reports the descriptor of the wrapped file.
            assert fz.fileno() == f.fileno()
            # A write-mode wrapper rejects read and seek operations.
            with raises(io.UnsupportedOperation):
                fz._check_can_read()

            with raises(io.UnsupportedOperation):
                fz._check_can_seek()
        # Leaving the inner context closes the wrapper.
        assert fz.closed
        with raises(ValueError):
            fz._check_not_closed()

    # Read back through a wrapped file object opened with the default mode;
    # the assertions below expect it to be readable and seekable.
    with open(filename, "rb") as f:
        with BinaryZlibFile(f) as fz:
            assert fz.readable()
            assert fz.seekable()
            assert fz.fileno() == f.fileno()
            assert fz.read() == data
            with raises(io.UnsupportedOperation):
                fz._check_can_write()
            assert fz.seekable()
            fz.seek(0)
            assert fz.tell() == 0
        assert fz.closed

    # Test with a filename as input
    with BinaryZlibFile(filename, "wb", compresslevel=compress_level) as fz:
        assert fz.writable()
        fz.write(data)

    with BinaryZlibFile(filename, "rb") as fz:
        assert fz.read() == data
        assert fz.seekable()

    # Test without context manager: close() is called explicitly.
    fz = BinaryZlibFile(filename, "wb", compresslevel=compress_level)
    assert fz.writable()
    fz.write(data)
    fz.close()

    fz = BinaryZlibFile(filename, "rb")
    assert fz.read() == data
    fz.close()
@parametrize("bad_value", [-1, 10, 15, "a", (), {}])
def test_binary_zlibfile_bad_compression_levels(tmpdir, bad_value):
    """Out-of-range or non-integer compress levels are rejected."""
    path = tmpdir.join("test.pkl").strpath
    expected = re.escape(
        "'compresslevel' must be an integer between 1 and 9. "
        "You provided 'compresslevel={}'".format(bad_value)
    )
    with raises(ValueError) as excinfo:
        BinaryZlibFile(path, "wb", compresslevel=bad_value)
    excinfo.match(expected)
@parametrize("bad_mode", ["a", "x", "r", "w", 1, 2])
def test_binary_zlibfile_invalid_modes(tmpdir, bad_mode):
    """Unsupported modes are rejected with an ``Invalid mode`` error."""
    path = tmpdir.join("test.pkl").strpath
    with raises(ValueError) as excinfo:
        BinaryZlibFile(path, bad_mode)
    excinfo.match("Invalid mode")
@parametrize("bad_file", [1, (), {}])
def test_binary_zlibfile_invalid_filename_type(bad_file):
    """Objects that are neither a path nor a file object raise TypeError."""
    expected_fragment = "filename must be a str or bytes object, or a file"
    with raises(TypeError) as excinfo:
        BinaryZlibFile(bad_file, "rb")
    excinfo.match(expected_fragment)
###############################################################################
# Test dumping array subclasses
if np is not None:

    class SubArray(np.ndarray):
        """ndarray subclass whose pickling is routed through ``_load_sub_array``."""

        def __reduce__(self):
            # Reduce to a plain ndarray; unpickling rebuilds the subclass.
            plain = np.asarray(self)
            return _load_sub_array, (plain,)

    def _load_sub_array(arr):
        """Rebuild a SubArray with the same shape and contents as ``arr``."""
        rebuilt = SubArray(arr.shape)
        rebuilt[:] = arr
        return rebuilt

    class ComplexTestObject:
        """A complex object containing numpy arrays as attributes."""

        def __init__(self):
            self.array_float = np.arange(100, dtype="float64")
            self.array_int = np.ones(100, dtype="int32")
            self.array_obj = np.array(["a", 10, 20.0], dtype="object")
@with_numpy
def test_numpy_subclass(tmpdir):
    """A dumped ndarray subclass reloads as the same subclass with equal values."""
    original = SubArray((10,))
    path = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump(original, path)
    restored = numpy_pickle.load(path)
    assert isinstance(restored, SubArray)
    np.testing.assert_array_equal(restored, original)
def test_pathlib(tmpdir):
    """dump and load accept both ``str`` and ``pathlib.Path`` filenames.

    Each direction is checked: dump to a Path and load from a str, then dump
    to a str and load from a Path.
    """
    path_str = tmpdir.join("test.pkl").strpath
    path_obj = Path(path_str)
    value = 123
    for dump_target, load_source in ((path_obj, path_str), (path_str, path_obj)):
        numpy_pickle.dump(value, dump_target)
        assert numpy_pickle.load(load_source) == value
@with_numpy
def test_non_contiguous_array_pickling(tmpdir):
    """Arrays that are neither C- nor F-contiguous survive a dump/load cycle."""
    path = tmpdir.join("test.pkl").strpath
    candidates = (
        # Array that triggers a contiguousness issue with nditer, see
        # https://github.com/joblib/joblib/pull/352 and
        # https://github.com/joblib/joblib/pull/353
        np.asfortranarray([[1, 2], [3, 4]])[1:],
        # Non-contiguous array that works fine with nditer
        np.ones((10, 50, 20), order="F")[:, :1, :],
    )
    for arr in candidates:
        # Precondition: the fixture really is non-contiguous in both orders.
        assert not (arr.flags.c_contiguous or arr.flags.f_contiguous)
        numpy_pickle.dump(arr, path)
        np.testing.assert_array_equal(numpy_pickle.load(path), arr)
@with_numpy
def test_pickle_highest_protocol(tmpdir):
    """A numpy array persists correctly when dumped with pickle.HIGHEST_PROTOCOL.

    Regression test for https://github.com/joblib/joblib/issues/362
    """
    expected = np.zeros(10)
    path = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump(expected, path, protocol=pickle.HIGHEST_PROTOCOL)
    np.testing.assert_array_equal(numpy_pickle.load(path), expected)
@with_numpy
def test_pickle_in_socket():
    """Check that joblib can dump to and load from socket file objects.

    Two transfers are exercised over one loopback connection:
    dumping directly into the server socket's file object, and dumping
    into an in-memory buffer and sending the raw bytes. In both cases the
    client reads the array back through a socket file object.
    """
    test_array = np.arange(10)

    # Bind to port 0 so the OS picks a free port. A hard-coded port can
    # collide with another process or a concurrently running test session.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
        listener.bind(("127.0.0.1", 0))
        listener.listen(1)
        addr = listener.getsockname()

        with socket.create_connection(addr) as client:
            server, _ = listener.accept()
            # Close the accepted socket too; the original leaked it.
            with server:
                with server.makefile("wb") as sf:
                    numpy_pickle.dump(test_array, sf)

                with client.makefile("rb") as cf:
                    array_reloaded = numpy_pickle.load(cf)

                np.testing.assert_array_equal(array_reloaded, test_array)

                # Check that a byte-aligned numpy array written in a file can
                # be sent over a socket and then read on the other side.
                bytes_to_send = io.BytesIO()
                numpy_pickle.dump(test_array, bytes_to_send)
                # sendall loops until the whole payload is written; a single
                # send() call may transmit only part of it.
                server.sendall(bytes_to_send.getvalue())

                with client.makefile("rb") as cf:
                    array_reloaded = numpy_pickle.load(cf)

                np.testing.assert_array_equal(array_reloaded, test_array)
@with_numpy
def test_load_memmap_with_big_offset(tmpdir):
    """A memmap offset larger than mmap.ALLOCATIONGRANULARITY is set correctly.

    See https://github.com/joblib/joblib/issues/451 and
    https://github.com/numpy/numpy/pull/8443 for more details.
    """
    granularity = mmap.ALLOCATIONGRANULARITY
    # The first array occupies a full granularity block, which pushes the
    # second array's offset past that boundary.
    arrays = [
        np.zeros(granularity, dtype="uint8"),
        np.ones(granularity, dtype="uint8"),
    ]
    path = tmpdir.join("test.mmap").strpath
    numpy_pickle.dump(arrays, path)
    loaded = numpy_pickle.load(path, mmap_mode="r")
    second = loaded[1]
    assert isinstance(second, np.memmap)
    assert second.offset > granularity
    np.testing.assert_array_equal(arrays, loaded)
def test_register_compressor(tmpdir):
    """Registering a custom compressor exposes its file factory and prefix.

    The dummy compressor is always removed from ``_COMPRESSORS`` afterwards,
    even when an assertion fails, so it cannot leak into other tests.
    """
    compressor_name = "test-name"
    compressor_prefix = "test-prefix"

    class BinaryCompressorTestFile(io.BufferedIOBase):
        pass

    class BinaryCompressorTestWrapper(CompressorWrapper):
        def __init__(self):
            CompressorWrapper.__init__(
                self, obj=BinaryCompressorTestFile, prefix=compressor_prefix
            )

    register_compressor(compressor_name, BinaryCompressorTestWrapper())
    try:
        assert _COMPRESSORS[compressor_name].fileobj_factory == BinaryCompressorTestFile
        assert _COMPRESSORS[compressor_name].prefix == compressor_prefix
    finally:
        # Clean up unconditionally. If left registered, a later registration
        # of the same name without force=True would raise "already registered".
        _COMPRESSORS.pop(compressor_name, None)
@parametrize("invalid_name", [1, (), {}])
def test_register_compressor_invalid_name(invalid_name):
    """A non-string compressor name is rejected with ValueError."""
    expected_fragment = "Compressor name should be a string"
    with raises(ValueError) as excinfo:
        register_compressor(invalid_name, None)
    excinfo.match(expected_fragment)
def test_register_compressor_invalid_fileobj():
    """A wrapper whose factory does not implement the file interface is rejected."""

    class InvalidFileObject:
        pass

    class InvalidFileObjectWrapper(CompressorWrapper):
        def __init__(self):
            super().__init__(obj=InvalidFileObject, prefix=b"prefix")

    expected_message = (
        "Compressor 'fileobj_factory' attribute should implement "
        "the file object interface"
    )
    with raises(ValueError) as excinfo:
        register_compressor("invalid", InvalidFileObjectWrapper())
    excinfo.match(expected_message)
class AnotherZlibCompressorWrapper(CompressorWrapper):
    """Test wrapper pairing BinaryZlibFile with a dummy ``b"prefix"`` prefix."""

    def __init__(self):
        super().__init__(obj=BinaryZlibFile, prefix=b"prefix")
class StandardLibGzipCompressorWrapper(CompressorWrapper):
    """Test wrapper pairing the stdlib gzip.GzipFile with a dummy prefix."""

    def __init__(self):
        super().__init__(obj=gzip.GzipFile, prefix=b"prefix")
def test_register_compressor_already_registered():
    """Re-registering a name fails unless ``force=True`` is passed.

    The dummy compressor is always removed from ``_COMPRESSORS`` afterwards,
    even when an assertion fails, so it cannot leak into other tests.
    """
    compressor_name = "test-name"
    # Register a test compressor.
    register_compressor(compressor_name, AnotherZlibCompressorWrapper())
    try:
        with raises(ValueError) as excinfo:
            register_compressor(compressor_name, StandardLibGzipCompressorWrapper())
        # The message is matched as a regex; escape its literal "." and quotes.
        excinfo.match(
            re.escape("Compressor '{}' already registered.".format(compressor_name))
        )

        register_compressor(
            compressor_name, StandardLibGzipCompressorWrapper(), force=True
        )
        assert compressor_name in _COMPRESSORS
        assert _COMPRESSORS[compressor_name].fileobj_factory == gzip.GzipFile
    finally:
        # Clean up unconditionally so other tests reusing this name still work.
        _COMPRESSORS.pop(compressor_name, None)
@with_lz4
def test_lz4_compression(tmpdir):
    """lz4 is usable, explicitly or via the ``.lz4`` extension, when installed."""
    import lz4.frame

    compressor = "lz4"
    assert compressor in _COMPRESSORS
    assert _COMPRESSORS[compressor].fileobj_factory == lz4.frame.LZ4FrameFile

    fname = tmpdir.join("test.pkl").strpath
    data = "test data"
    numpy_pickle.dump(data, fname, compress=compressor)
    with open(fname, "rb") as f:
        assert f.read(len(_LZ4_PREFIX)) == _LZ4_PREFIX
    assert numpy_pickle.load(fname) == data

    # Test that LZ4 is applied based on file extension. Inspect the file just
    # written; the previous version re-checked ``fname`` (compressed
    # explicitly above), so the extension path was never verified.
    fname_lz4 = fname + ".lz4"
    numpy_pickle.dump(data, fname_lz4)
    with open(fname_lz4, "rb") as f:
        assert f.read(len(_LZ4_PREFIX)) == _LZ4_PREFIX
    assert numpy_pickle.load(fname_lz4) == data
@without_lz4
def test_lz4_compression_without_lz4(tmpdir):
    """Without lz4, explicit and extension-based lz4 requests raise ValueError."""
    fname = tmpdir.join("test.nolz4").strpath
    data = "test data"
    attempts = (
        lambda: numpy_pickle.dump(data, fname, compress="lz4"),
        lambda: numpy_pickle.dump(data, fname + ".lz4"),
    )
    for attempt in attempts:
        with raises(ValueError) as excinfo:
            attempt()
        excinfo.match(LZ4_NOT_INSTALLED_ERROR)
# Pickle protocols to exercise: the default one, followed by the highest one
# when it differs. dict.fromkeys drops the duplicate while keeping order.
protocols = list(dict.fromkeys([pickle.DEFAULT_PROTOCOL, pickle.HIGHEST_PROTOCOL]))
@with_numpy
@parametrize("protocol", protocols)
def test_memmap_alignment_padding(tmpdir, protocol):
    """Memmaps returned by numpy_pickle.load are correctly aligned.

    Checks a single array, a list of arrays, and a dict of uint8 arrays
    with prime lengths from 2 to 23.
    """

    def check_aligned(expected, loaded):
        # Each reloaded item must be a memmap with the right content whose
        # data pointer is a multiple of NUMPY_ARRAY_ALIGNMENT_BYTES.
        assert isinstance(loaded, np.memmap)
        np.testing.assert_array_equal(expected, loaded)
        assert loaded.ctypes.data % numpy_pickle.NUMPY_ARRAY_ALIGNMENT_BYTES == 0
        assert loaded.flags.aligned

    def dump_then_memmap(obj, basename):
        # Use a fresh file for each dump: on Windows, reusing the same path
        # for a memmap raises OSError 22.
        path = tmpdir.join(basename).strpath
        numpy_pickle.dump(obj, path, protocol=protocol)
        return numpy_pickle.load(path, mmap_mode="r")

    single = np.random.randn(2)
    check_aligned(single, dump_then_memmap(single, "test.mmap"))

    array_list = [np.random.randn(2) for _ in range(4)]
    reloaded_list = dump_then_memmap(array_list, "test1.mmap")
    for idx, loaded in enumerate(reloaded_list):
        check_aligned(array_list[idx], loaded)

    prime_sizes = (2, 3, 5, 7, 11, 13, 17, 19, 23)
    array_dict = {
        "a{}".format(i): np.arange(size, dtype=np.uint8)
        for i, size in enumerate(prime_sizes)
    }
    reloaded_dict = dump_then_memmap(array_dict, "test2.mmap")
    for key, loaded in reloaded_dict.items():
        check_aligned(array_dict[key], loaded)
|