cloudpickle_wrapper.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import inspect
  2. from functools import partial
  3. from joblib.externals.cloudpickle import dumps, loads
# Module-level cache used by _wrap_objects_when_needed: maps an object that
# needed wrapping to the wrapper created for it, so that the same object is
# wrapped only once.
WRAP_CACHE = {}
  5. class CloudpickledObjectWrapper:
  6. def __init__(self, obj, keep_wrapper=False):
  7. self._obj = obj
  8. self._keep_wrapper = keep_wrapper
  9. def __reduce__(self):
  10. _pickled_object = dumps(self._obj)
  11. if not self._keep_wrapper:
  12. return loads, (_pickled_object,)
  13. return _reconstruct_wrapper, (_pickled_object, self._keep_wrapper)
  14. def __getattr__(self, attr):
  15. # Ensure that the wrapped object can be used seemlessly as the
  16. # previous object.
  17. if attr not in ["_obj", "_keep_wrapper"]:
  18. return getattr(self._obj, attr)
  19. return getattr(self, attr)
  20. # Make sure the wrapped object conserves the callable property
  21. class CallableObjectWrapper(CloudpickledObjectWrapper):
  22. def __call__(self, *args, **kwargs):
  23. return self._obj(*args, **kwargs)
  24. def _wrap_non_picklable_objects(obj, keep_wrapper):
  25. if callable(obj):
  26. return CallableObjectWrapper(obj, keep_wrapper=keep_wrapper)
  27. return CloudpickledObjectWrapper(obj, keep_wrapper=keep_wrapper)
  28. def _reconstruct_wrapper(_pickled_object, keep_wrapper):
  29. obj = loads(_pickled_object)
  30. return _wrap_non_picklable_objects(obj, keep_wrapper)
  31. def _wrap_objects_when_needed(obj):
  32. # Function to introspect an object and decide if it should be wrapped or
  33. # not.
  34. need_wrap = "__main__" in getattr(obj, "__module__", "")
  35. if isinstance(obj, partial):
  36. return partial(
  37. _wrap_objects_when_needed(obj.func),
  38. *[_wrap_objects_when_needed(a) for a in obj.args],
  39. **{
  40. k: _wrap_objects_when_needed(v)
  41. for k, v in obj.keywords.items()
  42. },
  43. )
  44. if callable(obj):
  45. # Need wrap if the object is a function defined in a local scope of
  46. # another function.
  47. func_code = getattr(obj, "__code__", "")
  48. need_wrap |= getattr(func_code, "co_flags", 0) & inspect.CO_NESTED
  49. # Need wrap if the obj is a lambda expression
  50. func_name = getattr(obj, "__name__", "")
  51. need_wrap |= "<lambda>" in func_name
  52. if not need_wrap:
  53. return obj
  54. wrapped_obj = WRAP_CACHE.get(obj)
  55. if wrapped_obj is None:
  56. wrapped_obj = _wrap_non_picklable_objects(obj, keep_wrapper=False)
  57. WRAP_CACHE[obj] = wrapped_obj
  58. return wrapped_obj
  59. def wrap_non_picklable_objects(obj, keep_wrapper=True):
  60. """Wrapper for non-picklable object to use cloudpickle to serialize them.
  61. Note that this wrapper tends to slow down the serialization process as it
  62. is done with cloudpickle which is typically slower compared to pickle. The
  63. proper way to solve serialization issues is to avoid defining functions and
  64. objects in the main scripts and to implement __reduce__ functions for
  65. complex classes.
  66. """
  67. # If obj is a class, create a CloudpickledClassWrapper which instantiates
  68. # the object internally and wrap it directly in a CloudpickledObjectWrapper
  69. if inspect.isclass(obj):
  70. class CloudpickledClassWrapper(CloudpickledObjectWrapper):
  71. def __init__(self, *args, **kwargs):
  72. self._obj = obj(*args, **kwargs)
  73. self._keep_wrapper = keep_wrapper
  74. CloudpickledClassWrapper.__name__ = obj.__name__
  75. return CloudpickledClassWrapper
  76. # If obj is an instance of a class, just wrap it in a regular
  77. # CloudpickledObjectWrapper
  78. return _wrap_non_picklable_objects(obj, keep_wrapper=keep_wrapper)