serialization.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import json
  2. import logging
  3. import types
  4. from ray import cloudpickle as cloudpickle
  5. from ray._common.utils import binary_to_hex, hex_to_binary
  6. from ray.util.annotations import DeveloperAPI
  7. from ray.util.debug import log_once
  8. logger = logging.getLogger(__name__)
  9. @DeveloperAPI
  10. class TuneFunctionEncoder(json.JSONEncoder):
  11. def default(self, obj):
  12. if isinstance(obj, types.FunctionType):
  13. return self._to_cloudpickle(obj)
  14. try:
  15. return super(TuneFunctionEncoder, self).default(obj)
  16. except Exception:
  17. if log_once(f"tune_func_encode:{str(obj)}"):
  18. logger.debug("Unable to encode. Falling back to cloudpickle.")
  19. return self._to_cloudpickle(obj)
  20. def _to_cloudpickle(self, obj):
  21. return {
  22. "_type": "CLOUDPICKLE_FALLBACK",
  23. "value": binary_to_hex(cloudpickle.dumps(obj)),
  24. }
  25. @DeveloperAPI
  26. class TuneFunctionDecoder(json.JSONDecoder):
  27. def __init__(self, *args, **kwargs):
  28. json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
  29. def object_hook(self, obj):
  30. if obj.get("_type") == "CLOUDPICKLE_FALLBACK":
  31. return self._from_cloudpickle(obj)
  32. return obj
  33. def _from_cloudpickle(self, obj):
  34. return cloudpickle.loads(hex_to_binary(obj["value"]))