numpy_pickle_utils.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. """Utilities for fast persistence of big data, with optional compression."""
  2. # Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
  3. # Copyright (c) 2009 Gael Varoquaux
  4. # License: BSD Style, 3 clauses.
  5. import contextlib
  6. import io
  7. import pickle
  8. import sys
  9. import warnings
  10. from .compressor import _COMPRESSORS, _ZFILE_PREFIX
  11. try:
  12. import numpy as np
  13. except ImportError:
  14. np = None
# Module-level aliases to the pure-Python pickler/unpickler implementations
# of the stdlib ``pickle`` module (``pickle._Pickler`` / ``pickle._Unpickler``)
# rather than the C-accelerated ``pickle.Pickler`` / ``pickle.Unpickler``.
# NOTE(review): presumably chosen so other modules can subclass them and
# customise their behaviour; confirm against the users of these names.
Unpickler = pickle._Unpickler
Pickler = pickle._Pickler
# Python 2 compatibility alias. NOTE(review): check whether any caller still
# imports it before removing.
xrange = range
  18. try:
  19. # The python standard library can be built without bz2 so we make bz2
  20. # usage optional.
  21. # see https://github.com/scikit-learn/scikit-learn/issues/7526 for more
  22. # details.
  23. import bz2
  24. except ImportError:
  25. bz2 = None
# Buffer size in bytes (1 MiB) passed to io.BufferedReader and
# io.BufferedWriter by _buffered_read_file and _buffered_write_file.
_IO_BUFFER_SIZE = 1024**2
  28. def _is_raw_file(fileobj):
  29. """Check if fileobj is a raw file object, e.g created with open."""
  30. fileobj = getattr(fileobj, "raw", fileobj)
  31. return isinstance(fileobj, io.FileIO)
  32. def _get_prefixes_max_len():
  33. # Compute the max prefix len of registered compressors.
  34. prefixes = [len(compressor.prefix) for compressor in _COMPRESSORS.values()]
  35. prefixes += [len(_ZFILE_PREFIX)]
  36. return max(prefixes)
  37. def _is_numpy_array_byte_order_mismatch(array):
  38. """Check if numpy array is having byte order mismatch"""
  39. return (
  40. sys.byteorder == "big"
  41. and (
  42. array.dtype.byteorder == "<"
  43. or (
  44. array.dtype.byteorder == "|"
  45. and array.dtype.fields
  46. and all(e[0].byteorder == "<" for e in array.dtype.fields.values())
  47. )
  48. )
  49. ) or (
  50. sys.byteorder == "little"
  51. and (
  52. array.dtype.byteorder == ">"
  53. or (
  54. array.dtype.byteorder == "|"
  55. and array.dtype.fields
  56. and all(e[0].byteorder == ">" for e in array.dtype.fields.values())
  57. )
  58. )
  59. )
  60. def _ensure_native_byte_order(array):
  61. """Use the byte order of the host while preserving values
  62. Does nothing if array already uses the system byte order.
  63. """
  64. if _is_numpy_array_byte_order_mismatch(array):
  65. array = array.byteswap().view(array.dtype.newbyteorder("="))
  66. return array
  67. ###############################################################################
  68. # Cache file utilities
  69. def _detect_compressor(fileobj):
  70. """Return the compressor matching fileobj.
  71. Parameters
  72. ----------
  73. fileobj: file object
  74. Returns
  75. -------
  76. str in {'zlib', 'gzip', 'bz2', 'lzma', 'xz', 'compat', 'not-compressed'}
  77. """
  78. # Read the magic number in the first bytes of the file.
  79. max_prefix_len = _get_prefixes_max_len()
  80. if hasattr(fileobj, "peek"):
  81. # Peek allows to read those bytes without moving the cursor in the
  82. # file which.
  83. first_bytes = fileobj.peek(max_prefix_len)
  84. else:
  85. # Fallback to seek if the fileobject is not peekable.
  86. first_bytes = fileobj.read(max_prefix_len)
  87. fileobj.seek(0)
  88. if first_bytes.startswith(_ZFILE_PREFIX):
  89. return "compat"
  90. else:
  91. for name, compressor in _COMPRESSORS.items():
  92. if first_bytes.startswith(compressor.prefix):
  93. return name
  94. return "not-compressed"
  95. def _buffered_read_file(fobj):
  96. """Return a buffered version of a read file object."""
  97. return io.BufferedReader(fobj, buffer_size=_IO_BUFFER_SIZE)
  98. def _buffered_write_file(fobj):
  99. """Return a buffered version of a write file object."""
  100. return io.BufferedWriter(fobj, buffer_size=_IO_BUFFER_SIZE)
  101. @contextlib.contextmanager
  102. def _validate_fileobject_and_memmap(fileobj, filename, mmap_mode=None):
  103. """Utility function opening the right fileobject from a filename.
  104. The magic number is used to choose between the type of file object to open:
  105. * regular file object (default)
  106. * zlib file object
  107. * gzip file object
  108. * bz2 file object
  109. * lzma file object (for xz and lzma compressor)
  110. Parameters
  111. ----------
  112. fileobj: file object
  113. filename: str
  114. filename path corresponding to the fileobj parameter.
  115. mmap_mode: str
  116. memory map mode that should be used to open the pickle file. This
  117. parameter is useful to verify that the user is not trying to one with
  118. compression. Default: None.
  119. Returns
  120. -------
  121. a tuple with a file like object, and the validated mmap_mode.
  122. """
  123. # Detect if the fileobj contains compressed data.
  124. compressor = _detect_compressor(fileobj)
  125. validated_mmap_mode = mmap_mode
  126. if compressor == "compat":
  127. # Compatibility with old pickle mode: simply return the input
  128. # filename "as-is" and let the compatibility function be called by the
  129. # caller.
  130. warnings.warn(
  131. "The file '%s' has been generated with a joblib "
  132. "version less than 0.10. "
  133. "Please regenerate this pickle file." % filename,
  134. DeprecationWarning,
  135. stacklevel=2,
  136. )
  137. yield filename, validated_mmap_mode
  138. else:
  139. if compressor in _COMPRESSORS:
  140. # based on the compressor detected in the file, we open the
  141. # correct decompressor file object, wrapped in a buffer.
  142. compressor_wrapper = _COMPRESSORS[compressor]
  143. inst = compressor_wrapper.decompressor_file(fileobj)
  144. fileobj = _buffered_read_file(inst)
  145. # Checking if incompatible load parameters with the type of file:
  146. # mmap_mode cannot be used with compressed file or in memory buffers
  147. # such as io.BytesIO.
  148. if mmap_mode is not None:
  149. validated_mmap_mode = None
  150. if isinstance(fileobj, io.BytesIO):
  151. warnings.warn(
  152. "In memory persistence is not compatible with "
  153. 'mmap_mode "%(mmap_mode)s" flag passed. '
  154. "mmap_mode option will be ignored." % locals(),
  155. stacklevel=2,
  156. )
  157. elif compressor != "not-compressed":
  158. warnings.warn(
  159. 'mmap_mode "%(mmap_mode)s" is not compatible '
  160. "with compressed file %(filename)s. "
  161. '"%(mmap_mode)s" flag will be ignored.' % locals(),
  162. stacklevel=2,
  163. )
  164. elif not _is_raw_file(fileobj):
  165. warnings.warn(
  166. '"%(fileobj)r" is not a raw file, mmap_mode '
  167. '"%(mmap_mode)s" flag will be ignored.' % locals(),
  168. stacklevel=2,
  169. )
  170. else:
  171. validated_mmap_mode = mmap_mode
  172. yield fileobj, validated_mmap_mode
  173. def _write_fileobject(filename, compress=("zlib", 3)):
  174. """Return the right compressor file object in write mode."""
  175. compressmethod = compress[0]
  176. compresslevel = compress[1]
  177. if compressmethod in _COMPRESSORS.keys():
  178. file_instance = _COMPRESSORS[compressmethod].compressor_file(
  179. filename, compresslevel=compresslevel
  180. )
  181. return _buffered_write_file(file_instance)
  182. else:
  183. file_instance = _COMPRESSORS["zlib"].compressor_file(
  184. filename, compresslevel=compresslevel
  185. )
  186. return _buffered_write_file(file_instance)
# Utility functions/variables from numpy required for writing arrays.
# We need at least the functions introduced in version 1.9 of numpy. Here,
# we use the ones from numpy 1.10.2.
# NOTE(review): BUFFER_SIZE is not referenced in this part of the file;
# check its users elsewhere before changing it.
BUFFER_SIZE = 2**18  # size of buffer for reading npz files in bytes (256 KiB)
  191. def _read_bytes(fp, size, error_template="ran out of data"):
  192. """Read from file-like object until size bytes are read.
  193. TODO python2_drop: is it still needed? The docstring mentions python 2.6
  194. and it looks like this can be at least simplified ...
  195. Raises ValueError if not EOF is encountered before size bytes are read.
  196. Non-blocking objects only supported if they derive from io objects.
  197. Required as e.g. ZipExtFile in python 2.6 can return less data than
  198. requested.
  199. This function was taken from numpy/lib/format.py in version 1.10.2.
  200. Parameters
  201. ----------
  202. fp: file-like object
  203. size: int
  204. error_template: str
  205. Returns
  206. -------
  207. a bytes object
  208. The data read in bytes.
  209. """
  210. data = bytes()
  211. while True:
  212. # io files (default in python3) return None or raise on
  213. # would-block, python2 file will truncate, probably nothing can be
  214. # done about that. note that regular files can't be non-blocking
  215. try:
  216. r = fp.read(size - len(data))
  217. data += r
  218. if len(r) == 0 or len(data) == size:
  219. break
  220. except io.BlockingIOError:
  221. pass
  222. if len(data) != size:
  223. msg = "EOF: reading %s, expected %d bytes got %d"
  224. raise ValueError(msg % (error_template, size, len(data)))
  225. else:
  226. return data
  227. def _reconstruct(*args, **kwargs):
  228. # Wrapper for numpy._core.multiarray._reconstruct with backward compat
  229. # for numpy 1.X
  230. #
  231. # XXX: Remove this function when numpy 1.X is not supported anymore
  232. np_major_version = np.__version__[:2]
  233. if np_major_version == "1.":
  234. from numpy.core.multiarray import _reconstruct as np_reconstruct
  235. elif np_major_version == "2.":
  236. from numpy._core.multiarray import _reconstruct as np_reconstruct
  237. return np_reconstruct(*args, **kwargs)