# reduction.py
###############################################################################
# Customizable Pickler with some basic reducers
#
# author: Thomas Moreau
#
# adapted from multiprocessing/reduction.py (17/02/2017)
#  * Replace the ForkingPickler with a similar _LokyPickler,
#  * Add CustomizablePickler to allow customizing the pickling process
#    on the fly.
#
import copyreg
import io
import functools
import types
import sys
import os
from importlib import import_module
from multiprocessing import util
from pickle import loads, HIGHEST_PROTOCOL
  19. ###############################################################################
  20. # Enable custom pickling in Loky.
  21. _dispatch_table = {}
  22. def register(type_, reduce_function):
  23. _dispatch_table[type_] = reduce_function
  24. ###############################################################################
  25. # Registers extra pickling routines to improve picklization for loky
  26. # make methods picklable
  27. def _reduce_method(m):
  28. if m.__self__ is None:
  29. return getattr, (m.__class__, m.__func__.__name__)
  30. else:
  31. return getattr, (m.__self__, m.__func__.__name__)
  32. class _C:
  33. def f(self):
  34. pass
  35. @classmethod
  36. def h(cls):
  37. pass
  38. register(type(_C().f), _reduce_method)
  39. register(type(_C.h), _reduce_method)
  40. def _reduce_method_descriptor(m):
  41. return getattr, (m.__objclass__, m.__name__)
  42. register(type(list.append), _reduce_method_descriptor)
  43. register(type(int.__add__), _reduce_method_descriptor)
  44. # Make partial func pickable
  45. def _reduce_partial(p):
  46. return _rebuild_partial, (p.func, p.args, p.keywords or {})
  47. def _rebuild_partial(func, args, keywords):
  48. return functools.partial(func, *args, **keywords)
  49. register(functools.partial, _reduce_partial)
  50. if sys.platform != "win32":
  51. from ._posix_reduction import _mk_inheritable # noqa: F401
  52. else:
  53. from . import _win_reduction # noqa: F401
  54. # global variable to change the pickler behavior
  55. try:
  56. from joblib.externals import cloudpickle # noqa: F401
  57. DEFAULT_ENV = "cloudpickle"
  58. except ImportError:
  59. # If cloudpickle is not present, fallback to pickle
  60. DEFAULT_ENV = "pickle"
  61. ENV_LOKY_PICKLER = os.environ.get("LOKY_PICKLER", DEFAULT_ENV)
  62. _LokyPickler = None
  63. _loky_pickler_name = None
  64. def set_loky_pickler(loky_pickler=None):
  65. global _LokyPickler, _loky_pickler_name
  66. if loky_pickler is None:
  67. loky_pickler = ENV_LOKY_PICKLER
  68. loky_pickler_cls = None
  69. # The default loky_pickler is cloudpickle
  70. if loky_pickler in ["", None]:
  71. loky_pickler = "cloudpickle"
  72. if loky_pickler == _loky_pickler_name:
  73. return
  74. if loky_pickler == "cloudpickle":
  75. from joblib.externals.cloudpickle import CloudPickler as loky_pickler_cls
  76. else:
  77. try:
  78. from importlib import import_module
  79. module_pickle = import_module(loky_pickler)
  80. loky_pickler_cls = module_pickle.Pickler
  81. except (ImportError, AttributeError) as e:
  82. extra_info = (
  83. "\nThis error occurred while setting loky_pickler to"
  84. f" '{loky_pickler}', as required by the env variable "
  85. "LOKY_PICKLER or the function set_loky_pickler."
  86. )
  87. e.args = (e.args[0] + extra_info,) + e.args[1:]
  88. e.msg = e.args[0]
  89. raise e
  90. util.debug(
  91. f"Using '{loky_pickler if loky_pickler else 'cloudpickle'}' for "
  92. "serialization."
  93. )
  94. class CustomizablePickler(loky_pickler_cls):
  95. _loky_pickler_cls = loky_pickler_cls
  96. def _set_dispatch_table(self, dispatch_table):
  97. for ancestor_class in self._loky_pickler_cls.mro():
  98. dt_attribute = getattr(ancestor_class, "dispatch_table", None)
  99. if isinstance(dt_attribute, types.MemberDescriptorType):
  100. # Ancestor class (typically _pickle.Pickler) has a
  101. # member_descriptor for its "dispatch_table" attribute. Use
  102. # it to set the dispatch_table as a member instead of a
  103. # dynamic attribute in the __dict__ of the instance,
  104. # otherwise it will not be taken into account by the C
  105. # implementation of the dump method if a subclass defines a
  106. # class-level dispatch_table attribute as was done in
  107. # cloudpickle 1.6.0:
  108. # https://github.com/joblib/loky/pull/260
  109. dt_attribute.__set__(self, dispatch_table)
  110. break
  111. # On top of member descriptor set, also use setattr such that code
  112. # that directly access self.dispatch_table gets a consistent view
  113. # of the same table.
  114. self.dispatch_table = dispatch_table
  115. def __init__(self, writer, reducers=None, protocol=HIGHEST_PROTOCOL):
  116. loky_pickler_cls.__init__(self, writer, protocol=protocol)
  117. if reducers is None:
  118. reducers = {}
  119. if hasattr(self, "dispatch_table"):
  120. # Force a copy that we will update without mutating the
  121. # any class level defined dispatch_table.
  122. loky_dt = dict(self.dispatch_table)
  123. else:
  124. # Use standard reducers as bases
  125. loky_dt = copyreg.dispatch_table.copy()
  126. # Register loky specific reducers
  127. loky_dt.update(_dispatch_table)
  128. # Set the new dispatch table, taking care of the fact that we
  129. # need to use the member_descriptor when we inherit from a
  130. # subclass of the C implementation of the Pickler base class
  131. # with an class level dispatch_table attribute.
  132. self._set_dispatch_table(loky_dt)
  133. # Register the reducers
  134. for type, reduce_func in reducers.items():
  135. self.register(type, reduce_func)
  136. def register(self, type, reduce_func):
  137. """Attach a reducer function to a given type in the dispatch table."""
  138. self.dispatch_table[type] = reduce_func
  139. _LokyPickler = CustomizablePickler
  140. _loky_pickler_name = loky_pickler
  141. def get_loky_pickler_name():
  142. global _loky_pickler_name
  143. return _loky_pickler_name
  144. def get_loky_pickler():
  145. global _LokyPickler
  146. return _LokyPickler
  147. # Set it to its default value
  148. set_loky_pickler()
  149. def dump(obj, file, reducers=None, protocol=None):
  150. """Replacement for pickle.dump() using _LokyPickler."""
  151. global _LokyPickler
  152. _LokyPickler(file, reducers=reducers, protocol=protocol).dump(obj)
  153. def dumps(obj, reducers=None, protocol=None):
  154. global _LokyPickler
  155. buf = io.BytesIO()
  156. dump(obj, buf, reducers=reducers, protocol=protocol)
  157. return buf.getbuffer()
  158. __all__ = ["dump", "dumps", "loads", "register", "set_loky_pickler"]
  159. if sys.platform == "win32":
  160. from multiprocessing.reduction import duplicate
  161. __all__ += ["duplicate"]