serialization.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. import logging
  2. import pickle
  3. from typing import Any, Dict, Tuple
  4. from ray import cloudpickle
  5. from ray.serve._private.constants import SERVE_LOGGER_NAME
  6. try:
  7. import orjson
  8. except ImportError:
  9. orjson = None
  10. try:
  11. import ormsgpack
  12. except ImportError:
  13. ormsgpack = None
  14. logger = logging.getLogger(SERVE_LOGGER_NAME)
  15. class SerializationMethod:
  16. """Available serialization methods for RPC communication."""
  17. CLOUDPICKLE = "cloudpickle"
  18. PICKLE = "pickle"
  19. MSGPACK = "msgpack"
  20. ORJSON = "orjson"
  21. NOOP = "noop"
  22. # Global cache for serializer instances to avoid per-request instantiation overhead
  23. _serializer_cache: Dict[Tuple[str, str], "RPCSerializer"] = {}
  24. class RPCSerializer:
  25. """Serializer for RPC communication with configurable serialization methods."""
  26. def __init__(
  27. self,
  28. request_method: str = SerializationMethod.CLOUDPICKLE,
  29. response_method: str = SerializationMethod.CLOUDPICKLE,
  30. ):
  31. self.request_method = request_method.lower()
  32. self.response_method = response_method.lower()
  33. self._validate_methods()
  34. self._setup_serializers()
  35. @classmethod
  36. def get_cached_serializer(
  37. cls,
  38. request_method: str = SerializationMethod.CLOUDPICKLE,
  39. response_method: str = SerializationMethod.CLOUDPICKLE,
  40. ) -> "RPCSerializer":
  41. """Get a cached serializer instance to avoid per-request instantiation overhead.
  42. This method maintains a cache of serializer instances based on
  43. (request_method, response_method) pairs, significantly reducing overhead
  44. in high-throughput systems.
  45. """
  46. # Normalize method names
  47. req_method = request_method.lower()
  48. resp_method = response_method.lower()
  49. cache_key = (req_method, resp_method)
  50. if cache_key not in _serializer_cache:
  51. _serializer_cache[cache_key] = cls(req_method, resp_method)
  52. return _serializer_cache[cache_key]
  53. def _validate_methods(self):
  54. """Validate that the serialization methods are supported."""
  55. valid_methods = {
  56. SerializationMethod.CLOUDPICKLE,
  57. SerializationMethod.PICKLE,
  58. SerializationMethod.MSGPACK,
  59. SerializationMethod.ORJSON,
  60. SerializationMethod.NOOP,
  61. }
  62. if self.request_method not in valid_methods:
  63. raise ValueError(
  64. f"Unsupported request serialization method: {self.request_method}. "
  65. f"Valid options: {valid_methods}"
  66. )
  67. if self.response_method not in valid_methods:
  68. raise ValueError(
  69. f"Unsupported response serialization method: {self.response_method}. "
  70. f"Valid options: {valid_methods}"
  71. )
  72. def _setup_serializers(self):
  73. """Setup the serialization functions based on the selected methods."""
  74. self._request_dumps, self._request_loads = self._get_serializer_funcs(
  75. self.request_method
  76. )
  77. self._response_dumps, self._response_loads = self._get_serializer_funcs(
  78. self.response_method
  79. )
  80. def _get_serializer_funcs(self, method: str) -> Tuple[Any, Any]:
  81. """Get dumps and loads functions for a given serialization method."""
  82. if method == SerializationMethod.CLOUDPICKLE:
  83. return cloudpickle.dumps, cloudpickle.loads
  84. elif method == SerializationMethod.PICKLE:
  85. return self._get_pickle_funcs()
  86. elif method == SerializationMethod.MSGPACK:
  87. return self._get_msgpack_funcs()
  88. elif method == SerializationMethod.ORJSON:
  89. return self._get_orjson_funcs()
  90. elif method == SerializationMethod.NOOP:
  91. return self._get_noop_funcs()
  92. def _get_noop_funcs(self) -> Tuple[Any, Any]:
  93. """Get no-op serialization functions for binary data."""
  94. def _noop_dumps(obj: Any) -> bytes:
  95. if not isinstance(obj, bytes):
  96. raise TypeError(
  97. f"a bytes-like object is required, got {type(obj).__name__}. "
  98. "Use a different serialization method for non-binary data."
  99. )
  100. return obj
  101. def _noop_loads(data: bytes) -> Any:
  102. if not isinstance(data, bytes):
  103. raise TypeError(
  104. f"a bytes-like object is required, got {type(data).__name__}. "
  105. "Use a different serialization method for non-binary data."
  106. )
  107. return data
  108. return _noop_dumps, _noop_loads
  109. def _get_pickle_funcs(self) -> Tuple[Any, Any]:
  110. """Get pickle serialization functions with highest protocol."""
  111. def _pickle_dumps(obj: Any) -> bytes:
  112. return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
  113. def _pickle_loads(data: bytes) -> Any:
  114. return pickle.loads(data)
  115. return _pickle_dumps, _pickle_loads
  116. def _get_msgpack_funcs(self) -> Tuple[Any, Any]:
  117. """Get msgpack serialization functions."""
  118. if ormsgpack is None:
  119. raise ImportError(
  120. "ormsgpack is not installed. Please install it with `pip install ormsgpack`."
  121. )
  122. # Configure ormsgpack with appropriate options
  123. def _msgpack_dumps(obj: Any) -> bytes:
  124. return ormsgpack.packb(obj)
  125. def _msgpack_loads(data: bytes) -> Any:
  126. return ormsgpack.unpackb(data)
  127. return _msgpack_dumps, _msgpack_loads
  128. def _get_orjson_funcs(self) -> Tuple[Any, Any]:
  129. """Get orjson serialization functions."""
  130. if orjson is None:
  131. raise ImportError(
  132. "orjson is not installed. Please install it with `pip install orjson`."
  133. )
  134. # orjson only supports JSON-serializable types
  135. def _orjson_dumps(obj: Any) -> bytes:
  136. try:
  137. return orjson.dumps(obj)
  138. except TypeError as e:
  139. raise TypeError(
  140. f"orjson serialization failed: {e}. "
  141. "Only JSON-serializable types are supported with orjson. "
  142. "Consider using 'cloudpickle' or 'pickle' for complex objects."
  143. )
  144. def _orjson_loads(data: bytes) -> Any:
  145. return orjson.loads(data)
  146. return _orjson_dumps, _orjson_loads
  147. def dumps_request(self, obj: Any) -> bytes:
  148. """Serialize a request object to bytes."""
  149. return self._request_dumps(obj)
  150. def loads_request(self, data: bytes) -> Any:
  151. """Deserialize bytes to a request object."""
  152. return self._request_loads(data)
  153. def dumps_response(self, obj: Any) -> bytes:
  154. """Serialize a response object to bytes."""
  155. return self._response_dumps(obj)
  156. def loads_response(self, data: bytes) -> Any:
  157. """Deserialize bytes to a response object."""
  158. return self._response_loads(data)
  159. def clear_serializer_cache():
  160. """Clear the cached serializer instances. Useful for testing or memory management."""
  161. global _serializer_cache
  162. _serializer_cache.clear()