chain.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. from typing import TYPE_CHECKING, Optional
  2. from ray.data.preprocessor import Preprocessor
  3. from ray.data.util.data_batch_conversion import BatchFormat
  4. if TYPE_CHECKING:
  5. from ray.air.data_batch_type import DataBatchType
  6. from ray.data.dataset import Dataset
  7. class Chain(Preprocessor):
  8. """Combine multiple preprocessors into a single :py:class:`Preprocessor`.
  9. When you call ``fit``, each preprocessor is fit on the dataset produced by the
  10. preceeding preprocessor's ``fit_transform``.
  11. Example:
  12. >>> import pandas as pd
  13. >>> import ray
  14. >>> from ray.data.preprocessors import *
  15. >>>
  16. >>> df = pd.DataFrame({
  17. ... "X0": [0, 1, 2],
  18. ... "X1": [3, 4, 5],
  19. ... "Y": ["orange", "blue", "orange"],
  20. ... })
  21. >>> ds = ray.data.from_pandas(df) # doctest: +SKIP
  22. >>>
  23. >>> preprocessor = Chain(
  24. ... StandardScaler(columns=["X0", "X1"]),
  25. ... Concatenator(columns=["X0", "X1"], output_column_name="X"),
  26. ... LabelEncoder(label_column="Y")
  27. ... )
  28. >>> preprocessor.fit_transform(ds).to_pandas() # doctest: +SKIP
  29. Y X
  30. 0 1 [-1.224744871391589, -1.224744871391589]
  31. 1 0 [0.0, 0.0]
  32. 2 1 [1.224744871391589, 1.224744871391589]
  33. Args:
  34. preprocessors: The preprocessors to sequentially compose.
  35. """
  36. def fit_status(self):
  37. fittable_count = 0
  38. fitted_count = 0
  39. for p in self.preprocessors:
  40. if p.fit_status() == Preprocessor.FitStatus.FITTED:
  41. fittable_count += 1
  42. fitted_count += 1
  43. elif p.fit_status() in (
  44. Preprocessor.FitStatus.NOT_FITTED,
  45. Preprocessor.FitStatus.PARTIALLY_FITTED,
  46. ):
  47. fittable_count += 1
  48. else:
  49. assert p.fit_status() == Preprocessor.FitStatus.NOT_FITTABLE
  50. if fittable_count > 0:
  51. if fitted_count == fittable_count:
  52. return Preprocessor.FitStatus.FITTED
  53. elif fitted_count > 0:
  54. return Preprocessor.FitStatus.PARTIALLY_FITTED
  55. else:
  56. return Preprocessor.FitStatus.NOT_FITTED
  57. else:
  58. return Preprocessor.FitStatus.NOT_FITTABLE
  59. def __init__(self, *preprocessors: Preprocessor):
  60. super().__init__()
  61. self.preprocessors = preprocessors
  62. def _fit(self, ds: "Dataset") -> Preprocessor:
  63. for preprocessor in self.preprocessors[:-1]:
  64. ds = preprocessor.fit_transform(ds)
  65. self.preprocessors[-1].fit(ds)
  66. return self
  67. def fit_transform(self, ds: "Dataset") -> "Dataset":
  68. for preprocessor in self.preprocessors:
  69. ds = preprocessor.fit_transform(ds)
  70. return ds
  71. def _transform(
  72. self,
  73. ds: "Dataset",
  74. batch_size: Optional[int],
  75. num_cpus: Optional[float] = None,
  76. memory: Optional[float] = None,
  77. concurrency: Optional[int] = None,
  78. ) -> "Dataset":
  79. for preprocessor in self.preprocessors:
  80. ds = preprocessor.transform(
  81. ds,
  82. batch_size=batch_size,
  83. num_cpus=num_cpus,
  84. memory=memory,
  85. concurrency=concurrency,
  86. )
  87. return ds
  88. def _transform_batch(self, df: "DataBatchType") -> "DataBatchType":
  89. for preprocessor in self.preprocessors:
  90. df = preprocessor.transform_batch(df)
  91. return df
  92. def __repr__(self):
  93. arguments = ", ".join(repr(preprocessor) for preprocessor in self.preprocessors)
  94. return f"{self.__class__.__name__}({arguments})"
  95. def _determine_transform_to_use(self) -> BatchFormat:
  96. # This is relevant for BatchPrediction.
  97. # For Chain preprocessor, we picked the first one as entry point.
  98. # TODO (jiaodong): We should revisit if our Chain preprocessor is
  99. # still optimal with context of lazy execution.
  100. return self.preprocessors[0]._determine_transform_to_use()