numpy_pickle_compat.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. """Numpy pickle compatibility functions."""
  2. import inspect
  3. import os
  4. import pickle
  5. import zlib
  6. from io import BytesIO
  7. from .numpy_pickle_utils import (
  8. _ZFILE_PREFIX,
  9. Unpickler,
  10. _ensure_native_byte_order,
  11. _reconstruct,
  12. )
  13. def hex_str(an_int):
  14. """Convert an int to an hexadecimal string."""
  15. return "{:#x}".format(an_int)
  16. def asbytes(s):
  17. if isinstance(s, bytes):
  18. return s
  19. return s.encode("latin1")
  20. _MAX_LEN = len(hex_str(2**64))
  21. _CHUNK_SIZE = 64 * 1024
  22. def read_zfile(file_handle):
  23. """Read the z-file and return the content as a string.
  24. Z-files are raw data compressed with zlib used internally by joblib
  25. for persistence. Backward compatibility is not guaranteed. Do not
  26. use for external purposes.
  27. """
  28. file_handle.seek(0)
  29. header_length = len(_ZFILE_PREFIX) + _MAX_LEN
  30. length = file_handle.read(header_length)
  31. length = length[len(_ZFILE_PREFIX) :]
  32. length = int(length, 16)
  33. # With python2 and joblib version <= 0.8.4 compressed pickle header is one
  34. # character wider so we need to ignore an additional space if present.
  35. # Note: the first byte of the zlib data is guaranteed not to be a
  36. # space according to
  37. # https://tools.ietf.org/html/rfc6713#section-2.1
  38. next_byte = file_handle.read(1)
  39. if next_byte != b" ":
  40. # The zlib compressed data has started and we need to go back
  41. # one byte
  42. file_handle.seek(header_length)
  43. # We use the known length of the data to tell Zlib the size of the
  44. # buffer to allocate.
  45. data = zlib.decompress(file_handle.read(), 15, length)
  46. assert len(data) == length, (
  47. "Incorrect data length while decompressing %s."
  48. "The file could be corrupted." % file_handle
  49. )
  50. return data
  51. def write_zfile(file_handle, data, compress=1):
  52. """Write the data in the given file as a Z-file.
  53. Z-files are raw data compressed with zlib used internally by joblib
  54. for persistence. Backward compatibility is not guaranteed. Do not
  55. use for external purposes.
  56. """
  57. file_handle.write(_ZFILE_PREFIX)
  58. length = hex_str(len(data))
  59. # Store the length of the data
  60. file_handle.write(asbytes(length.ljust(_MAX_LEN)))
  61. file_handle.write(zlib.compress(asbytes(data), compress))
  62. ###############################################################################
  63. # Utility objects for persistence.
  64. class NDArrayWrapper(object):
  65. """An object to be persisted instead of numpy arrays.
  66. The only thing this object does, is to carry the filename in which
  67. the array has been persisted, and the array subclass.
  68. """
  69. def __init__(self, filename, subclass, allow_mmap=True):
  70. """Constructor. Store the useful information for later."""
  71. self.filename = filename
  72. self.subclass = subclass
  73. self.allow_mmap = allow_mmap
  74. def read(self, unpickler):
  75. """Reconstruct the array."""
  76. filename = os.path.join(unpickler._dirname, self.filename)
  77. # Load the array from the disk
  78. # use getattr instead of self.allow_mmap to ensure backward compat
  79. # with NDArrayWrapper instances pickled with joblib < 0.9.0
  80. allow_mmap = getattr(self, "allow_mmap", True)
  81. kwargs = {}
  82. if allow_mmap:
  83. kwargs["mmap_mode"] = unpickler.mmap_mode
  84. if "allow_pickle" in inspect.signature(unpickler.np.load).parameters:
  85. # Required in numpy 1.16.3 and later to acknowledge the security
  86. # risk.
  87. kwargs["allow_pickle"] = True
  88. array = unpickler.np.load(filename, **kwargs)
  89. # Detect byte order mismatch and swap as needed.
  90. array = _ensure_native_byte_order(array)
  91. # Reconstruct subclasses. This does not work with old
  92. # versions of numpy
  93. if hasattr(array, "__array_prepare__") and self.subclass not in (
  94. unpickler.np.ndarray,
  95. unpickler.np.memmap,
  96. ):
  97. # We need to reconstruct another subclass
  98. new_array = _reconstruct(self.subclass, (0,), "b")
  99. return new_array.__array_prepare__(array)
  100. else:
  101. return array
  102. class ZNDArrayWrapper(NDArrayWrapper):
  103. """An object to be persisted instead of numpy arrays.
  104. This object store the Zfile filename in which
  105. the data array has been persisted, and the meta information to
  106. retrieve it.
  107. The reason that we store the raw buffer data of the array and
  108. the meta information, rather than array representation routine
  109. (tobytes) is that it enables us to use completely the strided
  110. model to avoid memory copies (a and a.T store as fast). In
  111. addition saving the heavy information separately can avoid
  112. creating large temporary buffers when unpickling data with
  113. large arrays.
  114. """
  115. def __init__(self, filename, init_args, state):
  116. """Constructor. Store the useful information for later."""
  117. self.filename = filename
  118. self.state = state
  119. self.init_args = init_args
  120. def read(self, unpickler):
  121. """Reconstruct the array from the meta-information and the z-file."""
  122. # Here we a simply reproducing the unpickling mechanism for numpy
  123. # arrays
  124. filename = os.path.join(unpickler._dirname, self.filename)
  125. array = _reconstruct(*self.init_args)
  126. with open(filename, "rb") as f:
  127. data = read_zfile(f)
  128. state = self.state + (data,)
  129. array.__setstate__(state)
  130. return array
  131. class ZipNumpyUnpickler(Unpickler):
  132. """A subclass of the Unpickler to unpickle our numpy pickles."""
  133. dispatch = Unpickler.dispatch.copy()
  134. def __init__(self, filename, file_handle, mmap_mode=None):
  135. """Constructor."""
  136. self._filename = os.path.basename(filename)
  137. self._dirname = os.path.dirname(filename)
  138. self.mmap_mode = mmap_mode
  139. self.file_handle = self._open_pickle(file_handle)
  140. Unpickler.__init__(self, self.file_handle)
  141. try:
  142. import numpy as np
  143. except ImportError:
  144. np = None
  145. self.np = np
  146. def _open_pickle(self, file_handle):
  147. return BytesIO(read_zfile(file_handle))
  148. def load_build(self):
  149. """Set the state of a newly created object.
  150. We capture it to replace our place-holder objects,
  151. NDArrayWrapper, by the array we are interested in. We
  152. replace them directly in the stack of pickler.
  153. """
  154. Unpickler.load_build(self)
  155. if isinstance(self.stack[-1], NDArrayWrapper):
  156. if self.np is None:
  157. raise ImportError(
  158. "Trying to unpickle an ndarray, but numpy didn't import correctly"
  159. )
  160. nd_array_wrapper = self.stack.pop()
  161. array = nd_array_wrapper.read(self)
  162. self.stack.append(array)
  163. dispatch[pickle.BUILD[0]] = load_build
  164. def load_compatibility(filename):
  165. """Reconstruct a Python object from a file persisted with joblib.dump.
  166. This function ensures the compatibility with joblib old persistence format
  167. (<= 0.9.3).
  168. Parameters
  169. ----------
  170. filename: string
  171. The name of the file from which to load the object
  172. Returns
  173. -------
  174. result: any Python object
  175. The object stored in the file.
  176. See Also
  177. --------
  178. joblib.dump : function to save an object
  179. Notes
  180. -----
  181. This function can load numpy array files saved separately during the
  182. dump.
  183. """
  184. with open(filename, "rb") as file_handle:
  185. # We are careful to open the file handle early and keep it open to
  186. # avoid race-conditions on renames. That said, if data is stored in
  187. # companion files, moving the directory will create a race when
  188. # joblib tries to access the companion files.
  189. unpickler = ZipNumpyUnpickler(filename, file_handle=file_handle)
  190. try:
  191. obj = unpickler.load()
  192. except UnicodeDecodeError as exc:
  193. # More user-friendly error message
  194. new_exc = ValueError(
  195. "You may be trying to read with "
  196. "python 3 a joblib pickle generated with python 2. "
  197. "This feature is not supported by joblib."
  198. )
  199. new_exc.__cause__ = exc
  200. raise new_exc
  201. finally:
  202. if hasattr(unpickler, "file_handle"):
  203. unpickler.file_handle.close()
  204. return obj