imputer.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. import logging
  2. from collections import Counter
  3. from numbers import Number
  4. from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
  5. import numpy as np
  6. import pandas as pd
  7. from pandas.api.types import is_categorical_dtype
  8. from ray.data.aggregate import Mean
  9. from ray.data.preprocessor import SerializablePreprocessorBase
  10. from ray.data.preprocessors.version_support import (
  11. SerializablePreprocessor as Serializable,
  12. )
  13. from ray.util.annotations import PublicAPI
  14. if TYPE_CHECKING:
  15. from ray.data.dataset import Dataset
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  17. @PublicAPI(stability="alpha")
  18. @Serializable(version=1, identifier="io.ray.preprocessors.simple_imputer")
  19. class SimpleImputer(SerializablePreprocessorBase):
  20. """Replace missing values with imputed values. If the column is missing from a
  21. batch, it will be filled with the imputed value.
  22. Examples:
  23. >>> import pandas as pd
  24. >>> import ray
  25. >>> from ray.data.preprocessors import SimpleImputer
  26. >>> df = pd.DataFrame({"X": [0, None, 3, 3], "Y": [None, "b", "c", "c"]})
  27. >>> ds = ray.data.from_pandas(df) # doctest: +SKIP
  28. >>> ds.to_pandas() # doctest: +SKIP
  29. X Y
  30. 0 0.0 None
  31. 1 NaN b
  32. 2 3.0 c
  33. 3 3.0 c
  34. The `"mean"` strategy imputes missing values with the mean of non-missing
  35. values. This strategy doesn't work with categorical data.
  36. >>> preprocessor = SimpleImputer(columns=["X"], strategy="mean")
  37. >>> preprocessor.fit_transform(ds).to_pandas() # doctest: +SKIP
  38. X Y
  39. 0 0.0 None
  40. 1 2.0 b
  41. 2 3.0 c
  42. 3 3.0 c
  43. The `"most_frequent"` strategy imputes missing values with the most frequent
  44. value in each column.
  45. >>> preprocessor = SimpleImputer(columns=["X", "Y"], strategy="most_frequent")
  46. >>> preprocessor.fit_transform(ds).to_pandas() # doctest: +SKIP
  47. X Y
  48. 0 0.0 c
  49. 1 3.0 b
  50. 2 3.0 c
  51. 3 3.0 c
  52. The `"constant"` strategy imputes missing values with the value specified by
  53. `fill_value`.
  54. >>> preprocessor = SimpleImputer(
  55. ... columns=["Y"],
  56. ... strategy="constant",
  57. ... fill_value="?",
  58. ... )
  59. >>> preprocessor.fit_transform(ds).to_pandas() # doctest: +SKIP
  60. X Y
  61. 0 0.0 ?
  62. 1 NaN b
  63. 2 3.0 c
  64. 3 3.0 c
  65. :class:`SimpleImputer` can also be used in append mode by providing the
  66. name of the output_columns that should hold the imputed values.
  67. >>> preprocessor = SimpleImputer(columns=["X"], output_columns=["X_imputed"], strategy="mean")
  68. >>> preprocessor.fit_transform(ds).to_pandas() # doctest: +SKIP
  69. X Y X_imputed
  70. 0 0.0 None 0.0
  71. 1 NaN b 2.0
  72. 2 3.0 c 3.0
  73. 3 3.0 c 3.0
  74. Args:
  75. columns: The columns to apply imputation to.
  76. strategy: How imputed values are chosen.
  77. * ``"mean"``: The mean of non-missing values. This strategy only works with numeric columns.
  78. * ``"most_frequent"``: The most common value.
  79. * ``"constant"``: The value passed to ``fill_value``.
  80. fill_value: The value to use when ``strategy`` is ``"constant"``.
  81. output_columns: The names of the transformed columns. If None, the transformed
  82. columns will be the same as the input columns. If not None, the length of
  83. ``output_columns`` must match the length of ``columns``, othwerwise an error
  84. will be raised.
  85. Raises:
  86. ValueError: if ``strategy`` is not ``"mean"``, ``"most_frequent"``, or
  87. ``"constant"``.
  88. """ # noqa: E501
  89. _valid_strategies = ["mean", "most_frequent", "constant"]
  90. def __init__(
  91. self,
  92. columns: List[str],
  93. strategy: str = "mean",
  94. fill_value: Optional[Union[str, Number]] = None,
  95. *,
  96. output_columns: Optional[List[str]] = None,
  97. ):
  98. super().__init__()
  99. self.columns = columns
  100. self.strategy = strategy
  101. self.fill_value = fill_value
  102. if strategy not in self._valid_strategies:
  103. raise ValueError(
  104. f"Strategy {strategy} is not supported."
  105. f"Supported values are: {self._valid_strategies}"
  106. )
  107. if strategy == "constant":
  108. # There is no information to be fitted.
  109. self._is_fittable = False
  110. if fill_value is None:
  111. raise ValueError(
  112. '`fill_value` must be set when using "constant" strategy.'
  113. )
  114. self.output_columns = (
  115. SerializablePreprocessorBase._derive_and_validate_output_columns(
  116. columns, output_columns
  117. )
  118. )
  119. def _fit(self, dataset: "Dataset") -> SerializablePreprocessorBase:
  120. if self.strategy == "mean":
  121. self.stat_computation_plan.add_aggregator(
  122. aggregator_fn=Mean, columns=self.columns
  123. )
  124. elif self.strategy == "most_frequent":
  125. self.stat_computation_plan.add_callable_stat(
  126. stat_fn=lambda key_gen: _get_most_frequent_values(
  127. dataset=dataset,
  128. columns=self.columns,
  129. key_gen=key_gen,
  130. ),
  131. stat_key_fn=lambda col: f"most_frequent({col})",
  132. columns=self.columns,
  133. )
  134. return self
  135. def _transform_pandas(self, df: pd.DataFrame):
  136. for column, output_column in zip(self.columns, self.output_columns):
  137. value = self._get_fill_value(column)
  138. if value is None:
  139. raise ValueError(
  140. f"Column {column} has no fill value. "
  141. "Check the data used to fit the SimpleImputer."
  142. )
  143. if column not in df.columns:
  144. # Create the column with the fill_value if it doesn't exist
  145. df[output_column] = value
  146. else:
  147. if is_categorical_dtype(df.dtypes[column]):
  148. df[output_column] = df[column].cat.add_categories([value])
  149. if (
  150. output_column != column
  151. # If the backing array is memory-mapped from shared memory, then the
  152. # array won't be writeable.
  153. or (
  154. isinstance(df[output_column].values, np.ndarray)
  155. and not df[output_column].values.flags.writeable
  156. )
  157. ):
  158. df[output_column] = df[column].copy(deep=True)
  159. df.fillna({output_column: value}, inplace=True)
  160. return df
  161. def _get_fill_value(self, column):
  162. if self.strategy == "mean":
  163. return self.stats_[f"mean({column})"]
  164. elif self.strategy == "most_frequent":
  165. return self.stats_[f"most_frequent({column})"]
  166. elif self.strategy == "constant":
  167. return self.fill_value
  168. else:
  169. raise ValueError(
  170. f"Strategy {self.strategy} is not supported. "
  171. "Supported values are: {self._valid_strategies}"
  172. )
  173. def __repr__(self):
  174. return (
  175. f"{self.__class__.__name__}(columns={self.columns!r}, "
  176. f"strategy={self.strategy!r}, fill_value={self.fill_value!r}, "
  177. f"output_columns={self.output_columns!r})"
  178. )
  179. def _get_serializable_fields(self) -> Dict[str, Any]:
  180. return {
  181. "columns": self.columns,
  182. "output_columns": self.output_columns,
  183. "_fitted": getattr(self, "_fitted", None),
  184. "strategy": self.strategy,
  185. "fill_value": getattr(self, "fill_value", None),
  186. }
  187. def _set_serializable_fields(self, fields: Dict[str, Any], version: int):
  188. # required fields
  189. self.columns = fields["columns"]
  190. self.output_columns = fields["output_columns"]
  191. self.strategy = fields["strategy"]
  192. # optional fields
  193. self._fitted = fields.get("_fitted")
  194. self.fill_value = fields.get("fill_value")
  195. if self.strategy == "constant":
  196. self._is_fittable = False
  197. def _get_most_frequent_values(
  198. dataset: "Dataset",
  199. columns: List[str],
  200. key_gen: Callable[[str], str],
  201. ) -> Dict[str, Union[str, Number]]:
  202. def get_pd_value_counts(df: pd.DataFrame) -> Dict[str, List[Counter]]:
  203. return {col: [Counter(df[col].value_counts().to_dict())] for col in columns}
  204. value_counts = dataset.map_batches(get_pd_value_counts, batch_format="pandas")
  205. final_counters = {col: Counter() for col in columns}
  206. for batch in value_counts.iter_batches(batch_size=None):
  207. for col, counters in batch.items():
  208. for counter in counters:
  209. final_counters[col] += counter
  210. return {
  211. key_gen(column): final_counters[column].most_common(1)[0][0] # noqa
  212. for column in columns
  213. }