serialization_handlers.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. """
  2. Serialization handlers for preprocessor save/load functionality.
  3. This module implements a factory pattern to abstract different serialization formats,
  4. making it easier to add new formats and maintain existing ones.
  5. """
  6. import abc
  7. import base64
  8. import pickle
  9. from enum import Enum
  10. from typing import Any, Dict, Optional, Union
  11. from ray.cloudpickle import cloudpickle
  12. from ray.util.annotations import DeveloperAPI
  13. @DeveloperAPI
  14. class HandlerFormatName(Enum):
  15. """Enum for consistent format naming in the factory."""
  16. CLOUDPICKLE = "cloudpickle"
  17. PICKLE = "pickle"
  18. @DeveloperAPI
  19. class SerializationHandler(abc.ABC):
  20. """Abstract base class for handling preprocessor serialization formats."""
  21. @abc.abstractmethod
  22. def serialize(
  23. self, data: Union["Preprocessor", Dict[str, Any]] # noqa: F821
  24. ) -> Union[str, bytes]:
  25. """Serialize preprocessor data to the specific format.
  26. Args:
  27. data: Dictionary containing preprocessor metadata and stats
  28. Returns:
  29. Serialized data in format-specific representation
  30. """
  31. pass
  32. @abc.abstractmethod
  33. def deserialize(self, serialized: Union[str, bytes]) -> Any:
  34. """Deserialize data from the specific format.
  35. Args:
  36. serialized: Serialized data in format-specific representation
  37. Returns:
  38. For structured formats (CloudPickle/JSON/MessagePack): Dictionary containing preprocessor metadata and stats
  39. For pickle format: The actual deserialized object
  40. """
  41. pass
  42. @abc.abstractmethod
  43. def get_magic_bytes(self) -> Union[str, bytes]:
  44. """Get the magic bytes/prefix for this format."""
  45. pass
  46. def strip_magic_bytes(self, serialized: Union[str, bytes]) -> Union[str, bytes]:
  47. """Remove magic bytes from serialized data."""
  48. magic = self.get_magic_bytes()
  49. if isinstance(serialized, (str, bytes)) and serialized.startswith(magic):
  50. return serialized[len(magic) :]
  51. return serialized
  52. @DeveloperAPI
  53. class CloudPickleSerializationHandler(SerializationHandler):
  54. """Handler for CloudPickle serialization format."""
  55. MAGIC_CLOUDPICKLE = b"CPKL:"
  56. def serialize(
  57. self, data: Union["Preprocessor", Dict[str, Any]] # noqa: F821
  58. ) -> bytes:
  59. """Serialize to CloudPickle format with magic prefix."""
  60. return self.MAGIC_CLOUDPICKLE + cloudpickle.dumps(data)
  61. def deserialize(self, serialized: bytes) -> Dict[str, Any]:
  62. """Deserialize from CloudPickle format."""
  63. if not isinstance(serialized, bytes):
  64. raise ValueError(
  65. f"Expected bytes for CloudPickle deserialization, got {type(serialized)}"
  66. )
  67. if not serialized.startswith(self.MAGIC_CLOUDPICKLE):
  68. raise ValueError(f"Invalid CloudPickle magic bytes: {serialized[:10]}")
  69. cloudpickle_data = self.strip_magic_bytes(serialized)
  70. return cloudpickle.loads(cloudpickle_data)
  71. def get_magic_bytes(self) -> bytes:
  72. return self.MAGIC_CLOUDPICKLE
  73. @DeveloperAPI
  74. class PickleSerializationHandler(SerializationHandler):
  75. """Handler for legacy Pickle serialization format."""
  76. def serialize(
  77. self, data: Union["Preprocessor", Dict[str, Any]] # noqa: F821
  78. ) -> str:
  79. """
  80. Serialize using pickle format (for backward compatibility).
  81. data is ignored, but kept for consistency
  82. """
  83. return base64.b64encode(pickle.dumps(data)).decode("ascii")
  84. def deserialize(
  85. self, serialized: str
  86. ) -> Any: # Returns the actual object, not metadata
  87. """Deserialize from pickle format (legacy support)."""
  88. # For pickle, we return the actual deserialized object directly
  89. return pickle.loads(base64.b64decode(serialized))
  90. def get_magic_bytes(self) -> str:
  91. return "" # Pickle format doesn't use magic bytes
  92. class SerializationHandlerFactory:
  93. """Factory class for creating appropriate serialization handlers."""
  94. _handlers = {
  95. HandlerFormatName.CLOUDPICKLE: CloudPickleSerializationHandler,
  96. HandlerFormatName.PICKLE: PickleSerializationHandler,
  97. }
  98. @classmethod
  99. def register_handler(cls, format_name: HandlerFormatName, handler_class: type):
  100. """Register a new serialization handler."""
  101. cls._handlers[format_name] = handler_class
  102. @classmethod
  103. def get_handler(
  104. cls,
  105. format_identifier: Optional[HandlerFormatName] = None,
  106. data: Optional[Union[str, bytes]] = None,
  107. **kwargs,
  108. ) -> SerializationHandler:
  109. """Get the appropriate serialization handler for a format or serialized data.
  110. Args:
  111. format_identifier: The format to use for serialization. If None, will detect from data.
  112. data: Serialized data to detect format from (used when format_identifier is None).
  113. **kwargs: Additional keyword arguments (currently unused).
  114. Returns:
  115. SerializationHandler instance for the format
  116. Raises:
  117. ValueError: If format is not supported or cannot be detected
  118. """
  119. # If it's already a format enum, use it directly
  120. if not format_identifier:
  121. format_identifier = cls.detect_format(data)
  122. if format_identifier not in cls._handlers:
  123. raise ValueError(
  124. f"Unsupported serialization format: {format_identifier.value}. "
  125. f"Supported formats: {list(cls._handlers.keys())}"
  126. )
  127. handler_class = cls._handlers[format_identifier]
  128. return handler_class()
  129. @classmethod
  130. def detect_format(cls, serialized: Union[str, bytes]) -> HandlerFormatName:
  131. """Detect the serialization format from the magic bytes.
  132. Args:
  133. serialized: Serialized data
  134. Returns:
  135. Format name enum
  136. Raises:
  137. ValueError: If format cannot be detected
  138. """
  139. # Check for CloudPickle first (binary format)
  140. if isinstance(serialized, bytes) and serialized.startswith(
  141. CloudPickleSerializationHandler.MAGIC_CLOUDPICKLE
  142. ):
  143. return HandlerFormatName.CLOUDPICKLE
  144. # Check for legacy pickle format (no magic bytes, should be base64 encoded)
  145. if isinstance(serialized, str):
  146. return HandlerFormatName.PICKLE
  147. raise ValueError(
  148. f"Cannot detect serialization format from: {serialized[:20]}..."
  149. )