create_numpy_pickle.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. """
  2. This script is used to generate test data for joblib/test/test_numpy_pickle.py
  3. """
  4. import re
  5. import sys
  6. # pytest needs to be able to import this module even when numpy is
  7. # not installed
  8. try:
  9. import numpy as np
  10. except ImportError:
  11. np = None
  12. import joblib
  13. def get_joblib_version(joblib_version=joblib.__version__):
  14. """Normalize joblib version by removing suffix.
  15. >>> get_joblib_version('0.8.4')
  16. '0.8.4'
  17. >>> get_joblib_version('0.8.4b1')
  18. '0.8.4'
  19. >>> get_joblib_version('0.9.dev0')
  20. '0.9'
  21. """
  22. matches = [re.match(r"(\d+).*", each) for each in joblib_version.split(".")]
  23. return ".".join([m.group(1) for m in matches if m is not None])
  24. def write_test_pickle(to_pickle, args):
  25. kwargs = {}
  26. compress = args.compress
  27. method = args.method
  28. joblib_version = get_joblib_version()
  29. py_version = "{0[0]}{0[1]}".format(sys.version_info)
  30. numpy_version = "".join(np.__version__.split(".")[:2])
  31. # The game here is to generate the right filename according to the options.
  32. body = "_compressed" if (compress and method == "zlib") else ""
  33. if compress:
  34. if method == "zlib":
  35. kwargs["compress"] = True
  36. extension = ".gz"
  37. else:
  38. kwargs["compress"] = (method, 3)
  39. extension = ".pkl.{}".format(method)
  40. if args.cache_size:
  41. kwargs["cache_size"] = 0
  42. body += "_cache_size"
  43. else:
  44. extension = ".pkl"
  45. pickle_filename = "joblib_{}{}_pickle_py{}_np{}{}".format(
  46. joblib_version, body, py_version, numpy_version, extension
  47. )
  48. try:
  49. joblib.dump(to_pickle, pickle_filename, **kwargs)
  50. except Exception as e:
  51. # With old python version (=< 3.3.), we can arrive there when
  52. # dumping compressed pickle with LzmaFile.
  53. print(
  54. "Error: cannot generate file '{}' with arguments '{}'. "
  55. "Error was: {}".format(pickle_filename, kwargs, e)
  56. )
  57. else:
  58. print("File '{}' generated successfully.".format(pickle_filename))
  59. if __name__ == "__main__":
  60. import argparse
  61. parser = argparse.ArgumentParser(description="Joblib pickle data generator.")
  62. parser.add_argument(
  63. "--cache_size",
  64. action="store_true",
  65. help="Force creation of companion numpy files for pickled arrays.",
  66. )
  67. parser.add_argument(
  68. "--compress", action="store_true", help="Generate compress pickles."
  69. )
  70. parser.add_argument(
  71. "--method",
  72. type=str,
  73. default="zlib",
  74. choices=["zlib", "gzip", "bz2", "xz", "lzma", "lz4"],
  75. help="Set compression method.",
  76. )
  77. # We need to be specific about dtypes in particular endianness
  78. # because the pickles can be generated on one architecture and
  79. # the tests run on another one. See
  80. # https://github.com/joblib/joblib/issues/279.
  81. to_pickle = [
  82. np.arange(5, dtype=np.dtype("<i8")),
  83. np.arange(5, dtype=np.dtype("<f8")),
  84. np.array([1, "abc", {"a": 1, "b": 2}], dtype="O"),
  85. # all possible bytes as a byte string
  86. np.arange(256, dtype=np.uint8).tobytes(),
  87. np.matrix([0, 1, 2], dtype=np.dtype("<i8")),
  88. # unicode string with non-ascii chars
  89. "C'est l'\xe9t\xe9 !",
  90. ]
  91. write_test_pickle(to_pickle, parser.parse_args())