# preprocessor.py
  1. import abc
  2. import base64
  3. import collections
  4. import logging
  5. import pickle
  6. import warnings
  7. from enum import Enum
  8. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, final
  9. from ray.data.util.data_batch_conversion import BatchFormat
  10. from ray.util.annotations import DeveloperAPI, PublicAPI
  11. if TYPE_CHECKING:
  12. import numpy as np
  13. import pandas as pd
  14. import pyarrow
  15. from ray.air.data_batch_type import DataBatchType
  16. from ray.data.dataset import Dataset
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  18. @PublicAPI(stability="beta")
  19. class PreprocessorNotFittedException(RuntimeError):
  20. """Error raised when the preprocessor needs to be fitted first."""
  21. pass
  22. @PublicAPI(stability="beta")
  23. class Preprocessor(abc.ABC):
  24. """Implements an ML preprocessing operation.
  25. Preprocessors are stateful objects that can be fitted against a Dataset and used
  26. to transform both local data batches and distributed data. For example, a
  27. Normalization preprocessor may calculate the mean and stdev of a field during
  28. fitting, and uses these attributes to implement its normalization transform.
  29. Preprocessors can also be stateless and transform data without needed to be fitted.
  30. For example, a preprocessor may simply remove a column, which does not require
  31. any state to be fitted.
  32. If you are implementing your own Preprocessor sub-class, you should override the
  33. following:
  34. * ``_fit`` if your preprocessor is stateful. Otherwise, set
  35. ``_is_fittable=False``.
  36. * ``_transform_pandas`` and/or ``_transform_numpy`` for best performance,
  37. implement both. Otherwise, the data will be converted to the match the
  38. implemented method.
  39. """
  40. def __init__(self):
  41. from ray.data.preprocessors.utils import StatComputationPlan
  42. self.stat_computation_plan = StatComputationPlan()
  43. self.stats_ = {}
  44. class FitStatus(str, Enum):
  45. """The fit status of preprocessor."""
  46. NOT_FITTABLE = "NOT_FITTABLE"
  47. NOT_FITTED = "NOT_FITTED"
  48. # Only meaningful for Chain preprocessors.
  49. # At least one contained preprocessor in the chain preprocessor
  50. # is fitted and at least one that can be fitted is not fitted yet.
  51. # This is a state that show up if caller only interacts
  52. # with the chain preprocessor through intended Preprocessor APIs.
  53. PARTIALLY_FITTED = "PARTIALLY_FITTED"
  54. FITTED = "FITTED"
  55. # Preprocessors that do not need to be fitted must override this.
  56. _is_fittable = True
  57. def _check_has_fitted_state(self):
  58. """Checks if the Preprocessor has fitted state.
  59. This is also used as an indication if the Preprocessor has been fit, following
  60. convention from Ray versions prior to 2.6.
  61. This allows preprocessors that have been fit in older versions of Ray to be
  62. used to transform data in newer versions.
  63. """
  64. fitted_vars = [v for v in vars(self) if v.endswith("_") and getattr(self, v)]
  65. return bool(fitted_vars)
  66. def fit_status(self) -> "Preprocessor.FitStatus":
  67. if not self._is_fittable:
  68. return Preprocessor.FitStatus.NOT_FITTABLE
  69. elif (
  70. hasattr(self, "_fitted") and self._fitted
  71. ) or self._check_has_fitted_state():
  72. return Preprocessor.FitStatus.FITTED
  73. else:
  74. return Preprocessor.FitStatus.NOT_FITTED
  75. def fit(self, ds: "Dataset") -> "Preprocessor":
  76. """Fit this Preprocessor to the Dataset.
  77. Fitted state attributes will be directly set in the Preprocessor.
  78. Calling it more than once will overwrite all previously fitted state:
  79. ``preprocessor.fit(A).fit(B)`` is equivalent to ``preprocessor.fit(B)``.
  80. Args:
  81. ds: Input dataset.
  82. Returns:
  83. Preprocessor: The fitted Preprocessor with state attributes.
  84. """
  85. fit_status = self.fit_status()
  86. if fit_status == Preprocessor.FitStatus.NOT_FITTABLE:
  87. # No-op as there is no state to be fitted.
  88. return self
  89. if fit_status in (
  90. Preprocessor.FitStatus.FITTED,
  91. Preprocessor.FitStatus.PARTIALLY_FITTED,
  92. ):
  93. warnings.warn(
  94. "`fit` has already been called on the preprocessor (or at least one "
  95. "contained preprocessors if this is a chain). "
  96. "All previously fitted state will be overwritten!"
  97. )
  98. self.stat_computation_plan.reset()
  99. self.stats_ = {}
  100. fitted_ds = self._fit(ds)._fit_execute(ds)
  101. self._fitted = True
  102. return fitted_ds
  103. def _fit_execute(self, dataset: "Dataset"):
  104. self.stats_ |= self.stat_computation_plan.compute(dataset)
  105. return self
  106. def has_stats(self) -> bool:
  107. return hasattr(self, "stats_") and len(self.stats_) > 0
  108. def fit_transform(
  109. self,
  110. ds: "Dataset",
  111. *,
  112. transform_num_cpus: Optional[float] = None,
  113. transform_memory: Optional[float] = None,
  114. transform_batch_size: Optional[int] = None,
  115. transform_concurrency: Optional[int] = None,
  116. ) -> "Dataset":
  117. """Fit this Preprocessor to the Dataset and then transform the Dataset.
  118. Calling it more than once will overwrite all previously fitted state:
  119. ``preprocessor.fit_transform(A).fit_transform(B)``
  120. is equivalent to ``preprocessor.fit_transform(B)``.
  121. Args:
  122. ds: Input Dataset.
  123. transform_num_cpus: [experimental] The number of CPUs to reserve for each parallel map worker.
  124. transform_memory: [experimental] The heap memory in bytes to reserve for each parallel map worker.
  125. transform_batch_size: [experimental] The maximum number of rows to return.
  126. transform_concurrency: [experimental] The maximum number of Ray workers to use concurrently.
  127. Returns:
  128. ray.data.Dataset: The transformed Dataset.
  129. """
  130. self.fit(ds)
  131. return self.transform(
  132. ds,
  133. num_cpus=transform_num_cpus,
  134. memory=transform_memory,
  135. batch_size=transform_batch_size,
  136. concurrency=transform_concurrency,
  137. )
  138. def transform(
  139. self,
  140. ds: "Dataset",
  141. *,
  142. batch_size: Optional[int] = None,
  143. num_cpus: Optional[float] = None,
  144. memory: Optional[float] = None,
  145. concurrency: Optional[int] = None,
  146. ) -> "Dataset":
  147. """Transform the given dataset.
  148. Args:
  149. ds: Input Dataset.
  150. batch_size: [experimental] Advanced configuration for adjusting input size for each worker.
  151. num_cpus: [experimental] The number of CPUs to reserve for each parallel map worker.
  152. memory: [experimental] The heap memory in bytes to reserve for each parallel map worker.
  153. concurrency: [experimental] The maximum number of Ray workers to use concurrently.
  154. Returns:
  155. ray.data.Dataset: The transformed Dataset.
  156. Raises:
  157. PreprocessorNotFittedException: if ``fit`` is not called yet.
  158. """
  159. fit_status = self.fit_status()
  160. if fit_status in (
  161. Preprocessor.FitStatus.PARTIALLY_FITTED,
  162. Preprocessor.FitStatus.NOT_FITTED,
  163. ):
  164. raise PreprocessorNotFittedException(
  165. "`fit` must be called before `transform`, "
  166. "or simply use fit_transform() to run both steps"
  167. )
  168. transformed_ds = self._transform(
  169. ds,
  170. batch_size=batch_size,
  171. num_cpus=num_cpus,
  172. memory=memory,
  173. concurrency=concurrency,
  174. )
  175. return transformed_ds
  176. def transform_batch(self, data: "DataBatchType") -> "DataBatchType":
  177. """Transform a single batch of data.
  178. The data will be converted to the format supported by the Preprocessor,
  179. based on which ``_transform_*`` methods are implemented.
  180. Args:
  181. data: Input data batch.
  182. Returns:
  183. DataBatchType:
  184. The transformed data batch. This may differ
  185. from the input type depending on which ``_transform_*`` methods
  186. are implemented.
  187. """
  188. fit_status = self.fit_status()
  189. if fit_status in (
  190. Preprocessor.FitStatus.PARTIALLY_FITTED,
  191. Preprocessor.FitStatus.NOT_FITTED,
  192. ):
  193. raise PreprocessorNotFittedException(
  194. "`fit` must be called before `transform_batch`."
  195. )
  196. return self._transform_batch(data)
  197. @DeveloperAPI
  198. def _fit(self, ds: "Dataset") -> "Preprocessor":
  199. """Sub-classes should override this instead of fit()."""
  200. raise NotImplementedError()
  201. def _determine_transform_to_use(self) -> BatchFormat:
  202. """Determine which batch format to use based on Preprocessor implementation.
  203. * If only `_transform_pandas` is implemented, then use ``pandas`` batch format.
  204. * If only `_transform_numpy` is implemented, then use ``numpy`` batch format.
  205. * If only `_transform_arrow` is implemented, then use ``arrow`` batch format.
  206. * If multiple are implemented, then use the Preprocessor defined preferred batch
  207. format.
  208. """
  209. has_transform_pandas = (
  210. self.__class__._transform_pandas != Preprocessor._transform_pandas
  211. )
  212. has_transform_numpy = (
  213. self.__class__._transform_numpy != Preprocessor._transform_numpy
  214. )
  215. has_transform_arrow = (
  216. self.__class__._transform_arrow != Preprocessor._transform_arrow
  217. )
  218. num_transforms = sum(
  219. [
  220. has_transform_pandas,
  221. has_transform_numpy,
  222. has_transform_arrow,
  223. ]
  224. )
  225. if num_transforms > 1:
  226. return self.preferred_batch_format()
  227. elif has_transform_arrow:
  228. return BatchFormat.ARROW
  229. elif has_transform_numpy:
  230. return BatchFormat.NUMPY
  231. elif has_transform_pandas:
  232. return BatchFormat.PANDAS
  233. else:
  234. raise NotImplementedError(
  235. "None of `_transform_numpy`, `_transform_pandas` or `_transform_arrow` "
  236. "are implemented. At least one of these transform functions must be "
  237. "implemented for Preprocessor transforms."
  238. )
  239. def _transform(
  240. self,
  241. ds: "Dataset",
  242. batch_size: Optional[int],
  243. num_cpus: Optional[float] = None,
  244. memory: Optional[float] = None,
  245. concurrency: Optional[int] = None,
  246. ) -> "Dataset":
  247. transform_type = self._determine_transform_to_use()
  248. # Our user-facing batch format should only be pandas or NumPy, other
  249. # formats {arrow, simple} are internal.
  250. kwargs = self._get_transform_config()
  251. if num_cpus is not None:
  252. kwargs["num_cpus"] = num_cpus
  253. if memory is not None:
  254. kwargs["memory"] = memory
  255. if batch_size is not None:
  256. kwargs["batch_size"] = batch_size
  257. if concurrency is not None:
  258. kwargs["concurrency"] = concurrency
  259. if transform_type == BatchFormat.PANDAS:
  260. return ds.map_batches(
  261. self._transform_pandas,
  262. batch_format=BatchFormat.PANDAS,
  263. zero_copy_batch=True,
  264. **kwargs,
  265. )
  266. elif transform_type == BatchFormat.NUMPY:
  267. return ds.map_batches(
  268. self._transform_numpy,
  269. batch_format=BatchFormat.NUMPY,
  270. zero_copy_batch=True,
  271. **kwargs,
  272. )
  273. elif transform_type == BatchFormat.ARROW:
  274. return ds.map_batches(
  275. self._transform_arrow,
  276. batch_format="pyarrow",
  277. zero_copy_batch=True,
  278. **kwargs,
  279. )
  280. else:
  281. raise ValueError(
  282. "Invalid transform type returned from _determine_transform_to_use; "
  283. f'"pandas" and "numpy" allowed, but got: {transform_type}'
  284. )
  285. def _get_transform_config(self) -> Dict[str, Any]:
  286. """Returns kwargs to be passed to :meth:`ray.data.Dataset.map_batches`.
  287. This can be implemented by subclassing preprocessors.
  288. """
  289. return {}
  290. def _transform_batch(self, data: "DataBatchType") -> "DataBatchType":
  291. import numpy as np
  292. import pandas as pd
  293. from ray.data.util.data_batch_conversion import (
  294. _convert_batch_type_to_numpy,
  295. _convert_batch_type_to_pandas,
  296. )
  297. try:
  298. import pyarrow
  299. except ImportError:
  300. pyarrow = None
  301. if not isinstance(
  302. data, (pd.DataFrame, pyarrow.Table, collections.abc.Mapping, np.ndarray)
  303. ):
  304. raise ValueError(
  305. "`transform_batch` is currently only implemented for Pandas "
  306. "DataFrames, pyarrow Tables, NumPy ndarray and dictionary of "
  307. f"ndarray. Got {type(data)}."
  308. )
  309. transform_type = self._determine_transform_to_use()
  310. if transform_type == BatchFormat.PANDAS:
  311. return self._transform_pandas(_convert_batch_type_to_pandas(data))
  312. elif transform_type == BatchFormat.NUMPY:
  313. return self._transform_numpy(_convert_batch_type_to_numpy(data))
  314. elif transform_type == BatchFormat.ARROW:
  315. # Convert input to Arrow table and use Arrow transform
  316. input_was_pandas = isinstance(data, pd.DataFrame)
  317. if isinstance(data, pyarrow.Table):
  318. arrow_table = data
  319. elif input_was_pandas:
  320. arrow_table = pyarrow.Table.from_pandas(data)
  321. else:
  322. # Convert to pandas first, then to Arrow
  323. arrow_table = pyarrow.Table.from_pandas(
  324. _convert_batch_type_to_pandas(data)
  325. )
  326. result = self._transform_arrow(arrow_table)
  327. # Convert back to pandas if input was pandas
  328. if input_was_pandas and isinstance(result, pyarrow.Table):
  329. return result.to_pandas()
  330. return result
  331. @classmethod
  332. def _derive_and_validate_output_columns(
  333. cls, columns: List[str], output_columns: Optional[List[str]]
  334. ) -> List[str]:
  335. """Returns the output columns after validation.
  336. Checks if the columns are explicitly set, otherwise defaulting to
  337. the input columns.
  338. Raises:
  339. ValueError: If the length of the output columns does not match the
  340. length of the input columns.
  341. """
  342. if output_columns and len(columns) != len(output_columns):
  343. raise ValueError(
  344. "Invalid output_columns: Got len(columns) != len(output_columns). "
  345. "The length of columns and output_columns must match."
  346. )
  347. return output_columns or columns
  348. @DeveloperAPI
  349. def _transform_arrow(self, table: "pyarrow.Table") -> "pyarrow.Table":
  350. """Run the transformation on a data batch in a PyArrow Table format."""
  351. raise NotImplementedError()
  352. @DeveloperAPI
  353. def _transform_pandas(self, df: "pd.DataFrame") -> "pd.DataFrame":
  354. """Run the transformation on a data batch in a Pandas DataFrame format."""
  355. raise NotImplementedError()
  356. @DeveloperAPI
  357. def _transform_numpy(
  358. self, np_data: Union["np.ndarray", Dict[str, "np.ndarray"]]
  359. ) -> Union["np.ndarray", Dict[str, "np.ndarray"]]:
  360. """Run the transformation on a data batch in a NumPy ndarray format."""
  361. raise NotImplementedError()
  362. @classmethod
  363. @DeveloperAPI
  364. def preferred_batch_format(cls) -> BatchFormat:
  365. """Batch format hint for upstream producers to try yielding best block format.
  366. The preferred batch format to use if multiple transform methods
  367. (`_transform_pandas`, `_transform_numpy`, `_transform_arrow`) are implemented.
  368. Defaults to Pandas.
  369. Can be overridden by Preprocessor classes depending on which transform
  370. path is the most optimal.
  371. """
  372. return BatchFormat.PANDAS
  373. def get_input_columns(self) -> List[str]:
  374. return getattr(self, "columns", [])
  375. def get_output_columns(self) -> List[str]:
  376. return getattr(self, "output_columns", [])
  377. def __getstate__(self) -> Dict[str, Any]:
  378. state = self.__dict__.copy()
  379. # Exclude unpicklable attributes
  380. state.pop("stat_computation_plan", None)
  381. return state
  382. def __setstate__(self, state: Dict[str, Any]):
  383. from ray.data.preprocessors.utils import StatComputationPlan
  384. self.__dict__.update(state)
  385. self.stat_computation_plan = StatComputationPlan()
  386. @DeveloperAPI
  387. def serialize(self) -> str:
  388. """Return this preprocessor serialized as a string.
  389. Note: This is not a stable serialization format as it uses `pickle`.
  390. """
  391. # Convert it to a plain string so that it can be included as JSON metadata
  392. # in Trainer checkpoints.
  393. return base64.b64encode(pickle.dumps(self)).decode("ascii")
  394. @staticmethod
  395. @DeveloperAPI
  396. def deserialize(serialized: str) -> "Preprocessor":
  397. """Load the original preprocessor serialized via `self.serialize()`."""
  398. return pickle.loads(base64.b64decode(serialized))
  399. @DeveloperAPI
  400. class SerializablePreprocessorBase(Preprocessor, abc.ABC):
  401. """Abstract base class for serializable preprocessors.
  402. This class defines the serialization interface that all preprocessors must implement
  403. to support saving and loading their state. The serialization system uses CloudPickle
  404. as the primary format.
  405. **Architecture Overview:**
  406. The serialization system is built around two types of methods:
  407. 1. **Final Methods (DO NOT OVERRIDE):**
  408. - ``serialize()``: Orchestrates the serialization process
  409. - ``deserialize()``: Orchestrates the deserialization process
  410. These methods are marked as ``@final`` and should never be overridden by
  411. subclasses. They handle format detection, factory coordination, and error handling.
  412. 2. **Abstract Methods (MUST IMPLEMENT):**
  413. - ``_get_serializable_fields()``: Extract instance fields for serialization
  414. - ``_set_serializable_fields()``: Restore instance fields from deserialization
  415. - ``_get_stats()``: Extract computed statistics for serialization
  416. - ``_set_stats()``: Restore computed statistics from deserialization
  417. These methods must be implemented by each preprocessor subclass to define
  418. their specific serialization behavior.
  419. **Format Support:**
  420. - **CloudPickle** (default):
  421. - **Pickle** (legacy): Backward compatibility for existing serialized data
  422. **Important Notes:**
  423. - Never override ``serialize()`` or ``deserialize()`` in subclasses
  424. - Always call ``super().__init__()`` in subclass constructors
  425. - Use ``_fitted`` attribute to track fitting state
  426. - Store computed statistics in ``stats_`` dictionary
  427. - Handle version migration and backwards compatibility in ``_set_serializable_fields()`` if needed
  428. """
  429. @DeveloperAPI
  430. class SerializationFormat(Enum):
  431. CLOUDPICKLE = "cloudpickle"
  432. PICKLE = "pickle" # legacy
  433. MAGIC_CLOUDPICKLE = b"CPKL:"
  434. SERIALIZER_FORMAT_VERSION = 1
  435. @abc.abstractmethod
  436. def _get_serializable_fields(self) -> Dict[str, Any]:
  437. """Extract instance fields that should be serialized.
  438. This method should return a dictionary containing all instance attributes
  439. that are necessary to restore the preprocessor's configuration state.
  440. This typically includes constructor parameters and internal state flags.
  441. Returns:
  442. Dictionary mapping field names to their values
  443. """
  444. pass
  445. @abc.abstractmethod
  446. def _set_serializable_fields(self, fields: Dict[str, Any], version: int):
  447. """Restore instance fields from deserialized data.
  448. This method should restore the preprocessor's configuration state from
  449. the provided fields' dictionary. It's called during deserialization to
  450. recreate the instance state.
  451. **Version Migration:**
  452. If the serialized version differs from the current ``VERSION``,
  453. implement migration logic to handle schema changes:
  454. .. testcode::
  455. def _set_serializable_fields(self, fields: Dict[str, Any], version: int):
  456. # Handle version migration
  457. if version == 1 and self.VERSION == 2:
  458. # Migrate from version 1 to 2
  459. if "old_field" in fields:
  460. fields["new_field"] = migrate_old_field(fields.pop("old_field"))
  461. # Set all fields
  462. for key, value in fields.items():
  463. setattr(self, key, value)
  464. # Reinitialize derived state
  465. self.stat_computation_plan = StatComputationPlan()
  466. Args:
  467. fields: Dictionary of field names to values
  468. version: Version of the serialized data
  469. """
  470. pass
  471. def _get_stats(self) -> Dict[str, Any]:
  472. """Extract computed statistics that should be serialized.
  473. This method should return the computed statistics that were generated
  474. during the ``fit()`` process. These statistics are typically stored in
  475. the ``stats_`` attribute and contain the learned parameters needed for
  476. transformation.
  477. Returns:
  478. Dictionary containing computed statistics
  479. """
  480. return getattr(self, "stats_", {})
  481. def _set_stats(self, stats: Dict[str, Any]):
  482. """Restore computed statistics from deserialized data.
  483. This method should restore the preprocessor's computed statistics from
  484. the provided stats dictionary. These statistics are typically stored in
  485. the ``stats_`` attribute and contain learned parameters from fitting.
  486. Args:
  487. stats: Dictionary containing computed statistics
  488. """
  489. self.stats_ = stats
  490. @classmethod
  491. def get_preprocessor_class_id(cls) -> str:
  492. """Get the preprocessor class identifier for this preprocessor class.
  493. Returns:
  494. The preprocessor class identifier string used to identify this preprocessor
  495. type in serialized data.
  496. """
  497. return cls.__PREPROCESSOR_CLASS_ID
  498. @classmethod
  499. def set_preprocessor_class_id(cls, identifier: str) -> None:
  500. """Set the preprocessor class identifier for this preprocessor class.
  501. Args:
  502. identifier: The preprocessor class identifier string to use.
  503. """
  504. cls.__PREPROCESSOR_CLASS_ID = identifier
  505. @classmethod
  506. def get_version(cls) -> int:
  507. """Get the version number for this preprocessor class.
  508. Returns:
  509. The version number for this preprocessor's serialization format.
  510. """
  511. return cls.__VERSION
  512. @classmethod
  513. def set_version(cls, version: int) -> None:
  514. """Set the version number for this preprocessor class.
  515. Args:
  516. version: The version number for this preprocessor's serialization format.
  517. """
  518. cls.__VERSION = version
  519. @final
  520. @DeveloperAPI
  521. def serialize(self) -> Union[str, bytes]:
  522. """Serialize this preprocessor to a string or bytes.
  523. **⚠️ DO NOT OVERRIDE THIS METHOD IN SUBCLASSES ⚠️**
  524. This method is marked as ``@final`` in the concrete implementation and handles
  525. the complete serialization orchestration. Subclasses should implement the
  526. abstract methods instead: ``_get_serializable_fields()`` and ``_get_stats()``.
  527. **Serialization Process:**
  528. 1. Extracts fields via ``_get_serializable_fields()``
  529. 2. Extracts statistics via ``_get_stats()``
  530. 3. Packages data with metadata (type, version, format)
  531. 4. Delegates to ``SerializationHandlerFactory`` for format-specific handling
  532. 5. Returns serialized data with magic bytes for format identification
  533. **Supported Formats:**
  534. - **CloudPickle** (default):
  535. - **Pickle** (legacy): Backward compatibility for existing serialized data
  536. Returns:
  537. Serialized preprocessor data (bytes for CloudPickle, str for legacy Pickle)
  538. Raises:
  539. ValueError: If the serialization format is invalid or unsupported
  540. """
  541. # Lazy import to avoid circular dependency
  542. from ray.data.preprocessors.serialization_handlers import (
  543. HandlerFormatName,
  544. SerializationHandlerFactory,
  545. )
  546. # Prepare data for CloudPickle format
  547. data = {
  548. "type": self.get_preprocessor_class_id(),
  549. "version": self.get_version(),
  550. "fields": self._get_serializable_fields(),
  551. "stats": self._get_stats(),
  552. # The `serializer_format_version` field is for versioning the structure of this
  553. # dictionary. It is separate from the preprocessor's own version and is not used currently.
  554. "serializer_format_version": self.SERIALIZER_FORMAT_VERSION,
  555. }
  556. return SerializationHandlerFactory.get_handler(
  557. format_identifier=HandlerFormatName.CLOUDPICKLE
  558. ).serialize(data)
  559. @final
  560. @staticmethod
  561. @DeveloperAPI
  562. def deserialize(serialized: Union[str, bytes]) -> "Preprocessor":
  563. """Deserialize a preprocessor from serialized data.
  564. **⚠️ DO NOT OVERRIDE THIS METHOD IN SUBCLASSES ⚠️**
  565. This method is marked as ``@final`` in the concrete implementation and handles
  566. the complete deserialization orchestration. Subclasses should implement the
  567. abstract methods instead: ``_set_serializable_fields()`` and ``_set_stats()``.
  568. **Deserialization Process:**
  569. 1. Detects format from magic bytes in serialized data
  570. 2. Delegates to ``SerializationHandlerFactory`` for format-specific parsing
  571. 3. Extracts metadata (type, version, fields, stats)
  572. 4. Looks up preprocessor class from registry
  573. 5. Creates new instance and restores state via abstract methods
  574. 6. Returns fully reconstructed preprocessor instance
  575. **Format Detection:**
  576. The method automatically detects the serialization format:
  577. - ``CPKL:`` → CloudPickle format
  578. - Base64 string → Legacy Pickle format
  579. **Error Handling:**
  580. Provides comprehensive error handling for:
  581. - Unknown serialization formats
  582. - Corrupted or invalid data
  583. - Missing preprocessor types
  584. - Version compatibility issues
  585. Args:
  586. serialized: Serialized preprocessor data (bytes or str)
  587. Returns:
  588. Reconstructed preprocessor instance
  589. Raises:
  590. ValueError: If the serialized data is corrupted or format is unrecognized
  591. UnknownPreprocessorError: If the preprocessor type is not registered
  592. """
  593. # Lazy imports to avoid circular dependency
  594. from ray.data.preprocessors.serialization_handlers import (
  595. PickleSerializationHandler,
  596. SerializationHandlerFactory,
  597. )
  598. from ray.data.preprocessors.version_support import (
  599. UnknownPreprocessorError,
  600. _lookup_class,
  601. )
  602. try:
  603. # Use factory to deserialize all formats (auto-detects format)
  604. handler = SerializationHandlerFactory.get_handler(data=serialized)
  605. meta = handler.deserialize(serialized)
  606. # Handle pickle specially - it returns the object directly
  607. if isinstance(handler, PickleSerializationHandler):
  608. return meta # For pickle, meta is actually the deserialized object
  609. # Reconstruct the preprocessor object for structured formats
  610. cls = _lookup_class(meta["type"])
  611. # Validate metadata
  612. if meta["serializer_format_version"] != cls.SERIALIZER_FORMAT_VERSION:
  613. raise ValueError(
  614. f"Unsupported serializer format version: {meta['serializer_format_version']}"
  615. )
  616. obj = cls.__new__(cls)
  617. # handle base class fields here
  618. from ray.data.preprocessors.utils import StatComputationPlan
  619. obj.stat_computation_plan = StatComputationPlan()
  620. obj._set_serializable_fields(fields=meta["fields"], version=meta["version"])
  621. obj._set_stats(stats=meta["stats"])
  622. return obj
  623. except UnknownPreprocessorError:
  624. # Let UnknownPreprocessorError pass through unchanged for specific error handling
  625. raise
  626. except Exception as e:
  627. # Provide more helpful error message for other exception types
  628. raise ValueError(
  629. f"Failed to deserialize preprocessor. Data preview: {serialized[:50]}..."
  630. ) from e
# Module-level alias for the enum nested in SerializablePreprocessorBase.
SerializationFormat = SerializablePreprocessorBase.SerializationFormat