# concatenator.py (6.2 KB)
  import logging
  from typing import Any, Dict, List, Optional

  import numpy as np
  import pandas as pd

  from ray.data.preprocessor import Preprocessor
  from ray.util.annotations import PublicAPI

  # Module-level logger named after this module.
  logger = logging.getLogger(__name__)
  @PublicAPI(stability="alpha")
  class Concatenator(Preprocessor):
      """Combine numeric columns into a column of type
      :class:`~ray.data._internal.tensor_extensions.pandas.TensorDtype`. Only columns
      specified in ``columns`` will be concatenated.

      This preprocessor concatenates numeric columns and stores the result in a new
      column. The new column contains
      :class:`~ray.data._internal.tensor_extensions.pandas.TensorArrayElement` objects of
      shape :math:`(m,)`, where :math:`m` is the number of columns concatenated
      (or the total flattened length when ``flatten=True``).
      The :math:`m` concatenated columns are dropped after concatenation.
      The preprocessor preserves the order of the columns provided in the ``columns``
      argument and will use that order when calling ``transform()`` and ``transform_batch()``.

      Examples:
          >>> import numpy as np
          >>> import pandas as pd
          >>> import ray
          >>> from ray.data.preprocessors import Concatenator

          :py:class:`Concatenator` combines numeric columns into a column of
          :py:class:`~ray.data._internal.tensor_extensions.pandas.TensorDtype`.

          >>> df = pd.DataFrame({"X0": [0, 3, 1], "X1": [0.5, 0.2, 0.9]})
          >>> ds = ray.data.from_pandas(df)  # doctest: +SKIP
          >>> concatenator = Concatenator(columns=["X0", "X1"])
          >>> concatenator.transform(ds).to_pandas()  # doctest: +SKIP
             concat_out
          0  [0.0, 0.5]
          1  [3.0, 0.2]
          2  [1.0, 0.9]

          By default, the created column is called `"concat_out"`, but you can specify
          a different name.

          >>> concatenator = Concatenator(columns=["X0", "X1"], output_column_name="tensor")
          >>> concatenator.transform(ds).to_pandas()  # doctest: +SKIP
                 tensor
          0  [0.0, 0.5]
          1  [3.0, 0.2]
          2  [1.0, 0.9]

          You can also cast the output tensors to a specific ``dtype``.

          >>> concatenator = Concatenator(columns=["X0", "X1"], dtype=np.float32)
          >>> concatenator.transform(ds)  # doctest: +SKIP
          Dataset(num_rows=3, schema={concat_out: TensorDtype(shape=(2,), dtype=float32)})

          When ``flatten=True``, nested vectors in the columns will be flattened during concatenation:

          >>> df = pd.DataFrame({"X0": [[1, 2], [3, 4]], "X1": [0.5, 0.2]})
          >>> ds = ray.data.from_pandas(df)  # doctest: +SKIP
          >>> concatenator = Concatenator(columns=["X0", "X1"], flatten=True)
          >>> concatenator.transform(ds).to_pandas()  # doctest: +SKIP
                  concat_out
          0  [1.0, 2.0, 0.5]
          1  [3.0, 4.0, 0.2]

      Args:
          columns: A list of columns to concatenate. The provided order of the columns
              will be retained during concatenation.
          output_column_name: The desired name for the new column.
              Defaults to ``"concat_out"``.
          dtype: The ``dtype`` to convert the output tensors to. If unspecified,
              the ``dtype`` is determined by standard coercion rules.
          raise_if_missing: If ``True``, an error is raised if any
              of the columns in ``columns`` don't exist.
              Defaults to ``False``.
          flatten: If ``True``, nested vectors in the columns will be flattened during
              concatenation. Defaults to ``False``.

      Raises:
          ValueError: if `raise_if_missing` is `True` and a column in `columns`
              doesn't exist in the dataset.
      """  # noqa: E501

      _is_fittable = False

      def __init__(
          self,
          columns: List[str],
          output_column_name: str = "concat_out",
          dtype: Optional[np.dtype] = None,
          raise_if_missing: bool = False,
          flatten: bool = False,
      ) -> None:
          super().__init__()
          self.columns = columns
          self.output_column_name = output_column_name
          self.dtype = dtype
          self.raise_if_missing = raise_if_missing
          self.flatten = flatten

      def _validate(self, df: pd.DataFrame) -> None:
          """Check that every configured column is present in ``df``.

          Raises:
              ValueError: if columns are missing and ``raise_if_missing`` is
                  ``True``. Otherwise the problem is only logged as a warning.
          """
          missing_columns = set(self.columns) - set(df)
          if missing_columns:
              message = (
                  f"Missing columns specified in '{self.columns}': {missing_columns}"
              )
              if self.raise_if_missing:
                  raise ValueError(message)
              else:
                  logger.warning(message)

      def _transform_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
          """Replace ``self.columns`` in ``df`` with one column of per-row arrays.

          Args:
              df: The batch to transform.

          Returns:
              A frame without the input columns and with
              ``self.output_column_name`` holding one array per row.
          """
          # NOTE: with raise_if_missing=False, _validate only warns; selecting a
          # missing label below still raises KeyError from pandas.
          self._validate(df)

          if self.flatten:
              concatenated = df[self.columns].to_numpy()
              # Each cell may be a scalar or a nested vector: promote to 1-D
              # (casting when a dtype was requested) and join per row.
              concatenated = [
                  np.concatenate(
                      [
                          np.atleast_1d(elem)
                          if self.dtype is None
                          else np.atleast_1d(elem).astype(self.dtype)
                          for elem in row
                      ]
                  )
                  for row in concatenated
              ]
          else:
              concatenated = df[self.columns].to_numpy(dtype=self.dtype)

          df = df.drop(columns=self.columns)
          # Use a Pandas Series for column assignment to get more consistent
          # behavior across Pandas versions. The Series must carry ``df``'s own
          # index: assignment aligns on index labels, so a default RangeIndex
          # would yield NaN / misplaced rows for any frame not indexed 0..n-1.
          df.loc[:, self.output_column_name] = pd.Series(
              list(concatenated), index=df.index
          )
          return df

      def get_input_columns(self) -> List[str]:
          """Return the names of the columns that get concatenated."""
          return self.columns

      def get_output_columns(self) -> List[str]:
          """Return a one-element list holding the output column name."""
          return [self.output_column_name]

      def __repr__(self) -> str:
          """Return ``ClassName(...)`` listing only the non-default arguments."""
          default_values = {
              "output_column_name": "concat_out",
              "columns": None,
              "dtype": None,
              "raise_if_missing": False,
              "flatten": False,
          }

          non_default_arguments = []
          for parameter, default_value in default_values.items():
              value = getattr(self, parameter)
              if value != default_value:
                  non_default_arguments.append(f"{parameter}={value}")

          return f"{self.__class__.__name__}({', '.join(non_default_arguments)})"

      def __setstate__(self, state: Dict[str, Any]) -> None:
          """Restore pickled state, defaulting fields added after serialization."""
          super().__setstate__(state)
          # flatten is a recent field, to ensure backwards compatibility
          # assign a default in case it is missing in the serialized state
          if not hasattr(self, "flatten"):
              self.flatten = False