| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195 |
- """
- Serialization handlers for preprocessor save/load functionality.
- This module implements a factory pattern to abstract different serialization formats,
- making it easier to add new formats and maintain existing ones.
- """
- import abc
- import base64
- import pickle
- from enum import Enum
- from typing import Any, Dict, Optional, Union
- from ray.cloudpickle import cloudpickle
- from ray.util.annotations import DeveloperAPI
- @DeveloperAPI
- class HandlerFormatName(Enum):
- """Enum for consistent format naming in the factory."""
- CLOUDPICKLE = "cloudpickle"
- PICKLE = "pickle"
- @DeveloperAPI
- class SerializationHandler(abc.ABC):
- """Abstract base class for handling preprocessor serialization formats."""
- @abc.abstractmethod
- def serialize(
- self, data: Union["Preprocessor", Dict[str, Any]] # noqa: F821
- ) -> Union[str, bytes]:
- """Serialize preprocessor data to the specific format.
- Args:
- data: Dictionary containing preprocessor metadata and stats
- Returns:
- Serialized data in format-specific representation
- """
- pass
- @abc.abstractmethod
- def deserialize(self, serialized: Union[str, bytes]) -> Any:
- """Deserialize data from the specific format.
- Args:
- serialized: Serialized data in format-specific representation
- Returns:
- For structured formats (CloudPickle/JSON/MessagePack): Dictionary containing preprocessor metadata and stats
- For pickle format: The actual deserialized object
- """
- pass
- @abc.abstractmethod
- def get_magic_bytes(self) -> Union[str, bytes]:
- """Get the magic bytes/prefix for this format."""
- pass
- def strip_magic_bytes(self, serialized: Union[str, bytes]) -> Union[str, bytes]:
- """Remove magic bytes from serialized data."""
- magic = self.get_magic_bytes()
- if isinstance(serialized, (str, bytes)) and serialized.startswith(magic):
- return serialized[len(magic) :]
- return serialized
- @DeveloperAPI
- class CloudPickleSerializationHandler(SerializationHandler):
- """Handler for CloudPickle serialization format."""
- MAGIC_CLOUDPICKLE = b"CPKL:"
- def serialize(
- self, data: Union["Preprocessor", Dict[str, Any]] # noqa: F821
- ) -> bytes:
- """Serialize to CloudPickle format with magic prefix."""
- return self.MAGIC_CLOUDPICKLE + cloudpickle.dumps(data)
- def deserialize(self, serialized: bytes) -> Dict[str, Any]:
- """Deserialize from CloudPickle format."""
- if not isinstance(serialized, bytes):
- raise ValueError(
- f"Expected bytes for CloudPickle deserialization, got {type(serialized)}"
- )
- if not serialized.startswith(self.MAGIC_CLOUDPICKLE):
- raise ValueError(f"Invalid CloudPickle magic bytes: {serialized[:10]}")
- cloudpickle_data = self.strip_magic_bytes(serialized)
- return cloudpickle.loads(cloudpickle_data)
- def get_magic_bytes(self) -> bytes:
- return self.MAGIC_CLOUDPICKLE
- @DeveloperAPI
- class PickleSerializationHandler(SerializationHandler):
- """Handler for legacy Pickle serialization format."""
- def serialize(
- self, data: Union["Preprocessor", Dict[str, Any]] # noqa: F821
- ) -> str:
- """
- Serialize using pickle format (for backward compatibility).
- data is ignored, but kept for consistency
- """
- return base64.b64encode(pickle.dumps(data)).decode("ascii")
- def deserialize(
- self, serialized: str
- ) -> Any: # Returns the actual object, not metadata
- """Deserialize from pickle format (legacy support)."""
- # For pickle, we return the actual deserialized object directly
- return pickle.loads(base64.b64decode(serialized))
- def get_magic_bytes(self) -> str:
- return "" # Pickle format doesn't use magic bytes
- class SerializationHandlerFactory:
- """Factory class for creating appropriate serialization handlers."""
- _handlers = {
- HandlerFormatName.CLOUDPICKLE: CloudPickleSerializationHandler,
- HandlerFormatName.PICKLE: PickleSerializationHandler,
- }
- @classmethod
- def register_handler(cls, format_name: HandlerFormatName, handler_class: type):
- """Register a new serialization handler."""
- cls._handlers[format_name] = handler_class
- @classmethod
- def get_handler(
- cls,
- format_identifier: Optional[HandlerFormatName] = None,
- data: Optional[Union[str, bytes]] = None,
- **kwargs,
- ) -> SerializationHandler:
- """Get the appropriate serialization handler for a format or serialized data.
- Args:
- format_identifier: The format to use for serialization. If None, will detect from data.
- data: Serialized data to detect format from (used when format_identifier is None).
- **kwargs: Additional keyword arguments (currently unused).
- Returns:
- SerializationHandler instance for the format
- Raises:
- ValueError: If format is not supported or cannot be detected
- """
- # If it's already a format enum, use it directly
- if not format_identifier:
- format_identifier = cls.detect_format(data)
- if format_identifier not in cls._handlers:
- raise ValueError(
- f"Unsupported serialization format: {format_identifier.value}. "
- f"Supported formats: {list(cls._handlers.keys())}"
- )
- handler_class = cls._handlers[format_identifier]
- return handler_class()
- @classmethod
- def detect_format(cls, serialized: Union[str, bytes]) -> HandlerFormatName:
- """Detect the serialization format from the magic bytes.
- Args:
- serialized: Serialized data
- Returns:
- Format name enum
- Raises:
- ValueError: If format cannot be detected
- """
- # Check for CloudPickle first (binary format)
- if isinstance(serialized, bytes) and serialized.startswith(
- CloudPickleSerializationHandler.MAGIC_CLOUDPICKLE
- ):
- return HandlerFormatName.CLOUDPICKLE
- # Check for legacy pickle format (no magic bytes, should be base64 encoded)
- if isinstance(serialized, str):
- return HandlerFormatName.PICKLE
- raise ValueError(
- f"Cannot detect serialization format from: {serialized[:20]}..."
- )
|