predictor.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import abc
  2. import warnings
  3. from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
  4. import numpy as np
  5. import pandas as pd
  6. from ray.air.data_batch_type import DataBatchType
  7. from ray.air.util.data_batch_conversion import (
  8. BatchFormat,
  9. _convert_batch_type_to_numpy,
  10. _convert_batch_type_to_pandas,
  11. )
  12. from ray.train import Checkpoint
  13. from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
  14. try:
  15. import pyarrow
  16. pa_table = pyarrow.Table
  17. except ImportError:
  18. pa_table = None
  19. if TYPE_CHECKING:
  20. from ray.data import Preprocessor
  21. # Reverse mapping from data batch type to batch format.
  22. TYPE_TO_ENUM: Dict[Type[DataBatchType], BatchFormat] = {
  23. np.ndarray: BatchFormat.NUMPY,
  24. dict: BatchFormat.NUMPY,
  25. pd.DataFrame: BatchFormat.PANDAS,
  26. }
  27. @PublicAPI(stability="beta")
  28. class PredictorNotSerializableException(RuntimeError):
  29. """Error raised when trying to serialize a Predictor instance."""
  30. pass
  31. @Deprecated
  32. class Predictor(abc.ABC):
  33. """Predictors load models from checkpoints to perform inference.
  34. .. note::
  35. The base ``Predictor`` class cannot be instantiated directly. Only one of
  36. its subclasses can be used.
  37. **How does a Predictor work?**
  38. Predictors expose a ``predict`` method that accepts an input batch of type
  39. ``DataBatchType`` and outputs predictions of the same type as the input batch.
  40. When the ``predict`` method is called the following occurs:
  41. - The input batch is converted into a pandas DataFrame. Tensor input (like a
  42. ``np.ndarray``) will be converted into a single column Pandas Dataframe.
  43. - If there is a :ref:`Preprocessor <preprocessor-ref>` saved in the provided
  44. :class:`Checkpoint <ray.train.Checkpoint>`, the preprocessor will be used to
  45. transform the DataFrame.
  46. - The transformed DataFrame will be passed to the model for inference (via the
  47. ``predictor._predict_pandas`` method).
  48. - The predictions will be outputted by ``predict`` in the same type as the
  49. original input.
  50. **How do I create a new Predictor?**
  51. To implement a new Predictor for your particular framework, you should subclass
  52. the base ``Predictor`` and implement the following two methods:
  53. 1. ``_predict_pandas``: Given a pandas.DataFrame input, return a
  54. pandas.DataFrame containing predictions.
  55. 2. ``from_checkpoint``: Logic for creating a Predictor from a
  56. :class:`Checkpoint <ray.train.Checkpoint>`.
  57. 3. Optionally ``_predict_numpy`` for better performance when working with
  58. tensor data to avoid extra copies from Pandas conversions.
  59. """
  60. def __init__(self, preprocessor: Optional["Preprocessor"] = None):
  61. """Subclasseses must call Predictor.__init__() to set a preprocessor."""
  62. warnings.warn(
  63. f"{self.__class__.__name__} is deprecated and will be removed after April 2026.",
  64. DeprecationWarning,
  65. stacklevel=2,
  66. )
  67. self._preprocessor: Optional[Preprocessor] = preprocessor
  68. # Whether tensor columns should be automatically cast from/to the tensor
  69. # extension type at UDF boundaries. This can be overridden by subclasses.
  70. self._cast_tensor_columns = False
  71. @classmethod
  72. @abc.abstractmethod
  73. def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "Predictor":
  74. """Create a specific predictor from a checkpoint.
  75. Args:
  76. checkpoint: Checkpoint to load predictor data from.
  77. kwargs: Arguments specific to predictor implementations.
  78. Returns:
  79. Predictor: Predictor object.
  80. """
  81. raise NotImplementedError
  82. @classmethod
  83. def from_pandas_udf(
  84. cls, pandas_udf: Callable[[pd.DataFrame], pd.DataFrame]
  85. ) -> "Predictor":
  86. """Create a Predictor from a Pandas UDF.
  87. Args:
  88. pandas_udf: A function that takes a pandas.DataFrame and other
  89. optional kwargs and returns a pandas.DataFrame.
  90. """
  91. class PandasUDFPredictor(Predictor):
  92. @classmethod
  93. def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "Predictor":
  94. return PandasUDFPredictor()
  95. def _predict_pandas(self, df, **kwargs) -> "pd.DataFrame":
  96. return pandas_udf(df, **kwargs)
  97. return PandasUDFPredictor()
  98. def get_preprocessor(self) -> Optional["Preprocessor"]:
  99. """Get the preprocessor to use prior to executing predictions."""
  100. return self._preprocessor
  101. def set_preprocessor(self, preprocessor: Optional["Preprocessor"]) -> None:
  102. """Set the preprocessor to use prior to executing predictions."""
  103. self._preprocessor = preprocessor
  104. @classmethod
  105. @DeveloperAPI
  106. def preferred_batch_format(cls) -> BatchFormat:
  107. """Batch format hint for upstream producers to try yielding best block format.
  108. The preferred batch format to use if both `_predict_pandas` and
  109. `_predict_numpy` are implemented. Defaults to Pandas.
  110. Can be overridden by predictor classes depending on the framework type,
  111. e.g. TorchPredictor prefers Numpy and XGBoostPredictor prefers Pandas as
  112. native batch format.
  113. """
  114. return BatchFormat.PANDAS
  115. @classmethod
  116. def _batch_format_to_use(cls) -> BatchFormat:
  117. """Determine the batch format to use for the predictor."""
  118. has_pandas_implemented = cls._predict_pandas != Predictor._predict_pandas
  119. has_numpy_implemented = cls._predict_numpy != Predictor._predict_numpy
  120. if has_pandas_implemented and has_numpy_implemented:
  121. return cls.preferred_batch_format()
  122. elif has_pandas_implemented:
  123. return BatchFormat.PANDAS
  124. elif has_numpy_implemented:
  125. return BatchFormat.NUMPY
  126. else:
  127. raise NotImplementedError(
  128. f"Predictor {cls.__name__} must implement at least one of "
  129. "`_predict_pandas` and `_predict_numpy`."
  130. )
  131. def _set_cast_tensor_columns(self):
  132. """Enable automatic tensor column casting.
  133. If this is called on a predictor, the predictor will cast tensor columns to
  134. NumPy ndarrays in the input to the preprocessors and cast tensor columns back to
  135. the tensor extension type in the prediction outputs.
  136. """
  137. self._cast_tensor_columns = True
  138. def predict(self, data: DataBatchType, **kwargs) -> DataBatchType:
  139. """Perform inference on a batch of data.
  140. Args:
  141. data: A batch of input data of type ``DataBatchType``.
  142. kwargs: Arguments specific to predictor implementations. These are passed
  143. directly to ``_predict_numpy`` or ``_predict_pandas``.
  144. Returns:
  145. DataBatchType:
  146. Prediction result. The return type will be the same as the input type.
  147. """
  148. if not hasattr(self, "_preprocessor"):
  149. raise NotImplementedError(
  150. "Subclasses of Predictor must call Predictor.__init__(preprocessor)."
  151. )
  152. try:
  153. batch_format = TYPE_TO_ENUM[type(data)]
  154. except KeyError:
  155. raise RuntimeError(
  156. f"Invalid input data type of {type(data)}, supported "
  157. f"types: {list(TYPE_TO_ENUM.keys())}"
  158. )
  159. if self._preprocessor:
  160. data = self._preprocessor.transform_batch(data)
  161. batch_format_to_use = self._batch_format_to_use()
  162. # We can finish prediction as long as one predict method is implemented.
  163. # For prediction, we have to return back in the same format as the input.
  164. if batch_format == BatchFormat.PANDAS:
  165. if batch_format_to_use == BatchFormat.PANDAS:
  166. return self._predict_pandas(
  167. _convert_batch_type_to_pandas(data), **kwargs
  168. )
  169. elif batch_format_to_use == BatchFormat.NUMPY:
  170. return _convert_batch_type_to_pandas(
  171. self._predict_numpy(_convert_batch_type_to_numpy(data), **kwargs)
  172. )
  173. elif batch_format == BatchFormat.NUMPY:
  174. if batch_format_to_use == BatchFormat.PANDAS:
  175. return _convert_batch_type_to_numpy(
  176. self._predict_pandas(_convert_batch_type_to_pandas(data), **kwargs)
  177. )
  178. elif batch_format_to_use == BatchFormat.NUMPY:
  179. return self._predict_numpy(_convert_batch_type_to_numpy(data), **kwargs)
  180. @DeveloperAPI
  181. def _predict_pandas(self, data: "pd.DataFrame", **kwargs) -> "pd.DataFrame":
  182. """Perform inference on a Pandas DataFrame.
  183. Args:
  184. data: A pandas DataFrame to perform predictions on.
  185. kwargs: Arguments specific to the predictor implementation.
  186. Returns:
  187. A pandas DataFrame containing the prediction result.
  188. """
  189. raise NotImplementedError
  190. @DeveloperAPI
  191. def _predict_numpy(
  192. self, data: Union[np.ndarray, Dict[str, np.ndarray]], **kwargs
  193. ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
  194. """Perform inference on a Numpy data.
  195. All Predictors working with tensor data (like deep learning predictors)
  196. should implement this method.
  197. Args:
  198. data: A Numpy ndarray or dictionary of ndarrays to perform predictions on.
  199. kwargs: Arguments specific to the predictor implementation.
  200. Returns:
  201. A Numpy ndarray or dictionary of ndarray containing the prediction result.
  202. """
  203. raise NotImplementedError
  204. def __reduce__(self):
  205. raise PredictorNotSerializableException(
  206. "Predictor instances are not serializable. Instead, you may want "
  207. "to serialize a checkpoint and initialize the Predictor with "
  208. "Predictor.from_checkpoint."
  209. )