encoder.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242
  1. import logging
  2. from collections import Counter
  3. from functools import partial
  4. from typing import (
  5. TYPE_CHECKING,
  6. Any,
  7. Callable,
  8. Dict,
  9. Hashable,
  10. List,
  11. Optional,
  12. Set,
  13. Tuple,
  14. Union,
  15. )
  16. import numpy as np
  17. import pandas as pd
  18. import pandas.api.types
  19. import pyarrow as pa
  20. import pyarrow.compute as pc
  21. from ray.data._internal.util import is_null
  22. from ray.data.block import BlockAccessor
  23. from ray.data.preprocessor import (
  24. Preprocessor,
  25. PreprocessorNotFittedException,
  26. SerializablePreprocessorBase,
  27. )
  28. from ray.data.preprocessors.utils import (
  29. make_post_processor,
  30. )
  31. from ray.data.preprocessors.version_support import SerializablePreprocessor
  32. from ray.data.util.data_batch_conversion import BatchFormat
  33. from ray.util.annotations import DeveloperAPI, PublicAPI
  34. if TYPE_CHECKING:
  35. from ray.data.dataset import Dataset
  36. logger = logging.getLogger(__name__)
  37. def _get_unique_value_arrow_arrays(
  38. stats: Dict[str, Any], input_col: str
  39. ) -> Tuple[pa.Array, pa.Array]:
  40. """Get Arrow arrays for keys and values from encoder stats.
  41. Args:
  42. stats: The encoder's stats_ dictionary.
  43. input_col: The name of the column to get arrays for.
  44. Returns:
  45. Tuple of (keys_array, values_array) for the column's ordinal mapping.
  46. """
  47. stat_value = stats[f"unique_values({input_col})"]
  48. if isinstance(stat_value, dict):
  49. # Stats are in pandas dict format - convert to Arrow format
  50. sorted_keys = sorted(stat_value.keys())
  51. keys_array = pa.array(sorted_keys)
  52. values_array = pa.array([stat_value[k] for k in sorted_keys], type=pa.int64())
  53. else:
  54. # Stats are in Arrow tuple format: (keys_array, values_array)
  55. keys_array, values_array = stat_value
  56. return keys_array, values_array
  57. @PublicAPI(stability="alpha")
  58. @SerializablePreprocessor(version=1, identifier="io.ray.preprocessors.ordinal_encoder")
  59. class OrdinalEncoder(SerializablePreprocessorBase):
  60. r"""Encode values within columns as ordered integer values.
  61. :class:`OrdinalEncoder` encodes categorical features as integers that range from
  62. :math:`0` to :math:`n - 1`, where :math:`n` is the number of categories.
  63. If you transform a value that isn't in the fitted datset, then the value is encoded
  64. as ``float("nan")``.
  65. Columns must contain either hashable values or lists of hashable values. Also, you
  66. can't have both scalars and lists in the same column.
  67. Examples:
  68. Use :class:`OrdinalEncoder` to encode categorical features as integers.
  69. >>> import pandas as pd
  70. >>> import ray
  71. >>> from ray.data.preprocessors import OrdinalEncoder
  72. >>> df = pd.DataFrame({
  73. ... "sex": ["male", "female", "male", "female"],
  74. ... "level": ["L4", "L5", "L3", "L4"],
  75. ... })
  76. >>> ds = ray.data.from_pandas(df) # doctest: +SKIP
  77. >>> encoder = OrdinalEncoder(columns=["sex", "level"])
  78. >>> encoder.fit_transform(ds).to_pandas() # doctest: +SKIP
  79. sex level
  80. 0 1 1
  81. 1 0 2
  82. 2 1 0
  83. 3 0 1
  84. :class:`OrdinalEncoder` can also be used in append mode by providing the
  85. name of the output_columns that should hold the encoded values.
  86. >>> encoder = OrdinalEncoder(columns=["sex", "level"], output_columns=["sex_encoded", "level_encoded"])
  87. >>> encoder.fit_transform(ds).to_pandas() # doctest: +SKIP
  88. sex level sex_encoded level_encoded
  89. 0 male L4 1 1
  90. 1 female L5 0 2
  91. 2 male L3 1 0
  92. 3 female L4 0 1
  93. If you transform a value not present in the original dataset, then the value
  94. is encoded as ``float("nan")``.
  95. >>> df = pd.DataFrame({"sex": ["female"], "level": ["L6"]})
  96. >>> ds = ray.data.from_pandas(df) # doctest: +SKIP
  97. >>> encoder.transform(ds).to_pandas() # doctest: +SKIP
  98. sex level
  99. 0 0 NaN
  100. :class:`OrdinalEncoder` can also encode categories in a list.
  101. >>> df = pd.DataFrame({
  102. ... "name": ["Shaolin Soccer", "Moana", "The Smartest Guys in the Room"],
  103. ... "genre": [
  104. ... ["comedy", "action", "sports"],
  105. ... ["animation", "comedy", "action"],
  106. ... ["documentary"],
  107. ... ],
  108. ... })
  109. >>> ds = ray.data.from_pandas(df) # doctest: +SKIP
  110. >>> encoder = OrdinalEncoder(columns=["genre"])
  111. >>> encoder.fit_transform(ds).to_pandas() # doctest: +SKIP
  112. name genre
  113. 0 Shaolin Soccer [2, 0, 4]
  114. 1 Moana [1, 2, 0]
  115. 2 The Smartest Guys in the Room [3]
  116. Args:
  117. columns: The columns to separately encode.
  118. encode_lists: If ``True``, encode list elements. If ``False``, encode
  119. whole lists (i.e., replace each list with an integer). ``True``
  120. by default.
  121. output_columns: The names of the transformed columns. If None, the transformed
  122. columns will be the same as the input columns. If not None, the length of
  123. ``output_columns`` must match the length of ``columns``, othwerwise an error
  124. will be raised.
  125. .. seealso::
  126. :class:`OneHotEncoder`
  127. Another preprocessor that encodes categorical data.
  128. """
  129. def __init__(
  130. self,
  131. columns: List[str],
  132. *,
  133. encode_lists: bool = True,
  134. output_columns: Optional[List[str]] = None,
  135. ):
  136. super().__init__()
  137. # TODO: allow user to specify order of values within each column.
  138. self.columns = columns
  139. self.encode_lists = encode_lists
  140. self.output_columns = Preprocessor._derive_and_validate_output_columns(
  141. columns, output_columns
  142. )
  143. def _fit(self, dataset: "Dataset") -> Preprocessor:
  144. self.stat_computation_plan.add_callable_stat(
  145. stat_fn=lambda key_gen: compute_unique_value_indices(
  146. dataset=dataset,
  147. columns=self.columns,
  148. encode_lists=self.encode_lists,
  149. key_gen=key_gen,
  150. ),
  151. post_process_fn=unique_post_fn(),
  152. stat_key_fn=lambda col: f"unique({col})",
  153. post_key_fn=lambda col: f"unique_values({col})",
  154. columns=self.columns,
  155. )
  156. return self
  157. def _get_ordinal_map(self, column_name: str) -> Dict[Any, int]:
  158. """Get the ordinal mapping for a column as a dict.
  159. Stats can be stored in either:
  160. - Dict format: {value: index} (from pandas-style processing)
  161. - Arrow format: (keys_array, values_array) tuple
  162. This method returns a dict in either case.
  163. """
  164. stat_value = self.stats_[f"unique_values({column_name})"]
  165. if isinstance(stat_value, dict):
  166. return stat_value
  167. # Arrow tuple format (keys_array, values_array)
  168. keys_array, values_array = stat_value
  169. return {k.as_py(): v.as_py() for k, v in zip(keys_array, values_array)}
  170. def _get_arrow_arrays(self, input_col: str) -> Tuple[pa.Array, pa.Array]:
  171. """Get Arrow arrays for keys and values."""
  172. return _get_unique_value_arrow_arrays(self.stats_, input_col)
  173. def _encode_list_element(self, element: list, *, column_name: str):
  174. ordinal_map = self._get_ordinal_map(column_name)
  175. # If encoding lists, entire column is flattened, hence we map individual
  176. # elements inside the list element (of the column)
  177. if self.encode_lists:
  178. return [ordinal_map.get(x) for x in element]
  179. return ordinal_map.get(tuple(element))
  180. def _transform_pandas(self, df: pd.DataFrame):
  181. _validate_df(df, *self.columns)
  182. def column_ordinal_encoder(s: pd.Series):
  183. if _is_series_composed_of_lists(s):
  184. return s.map(
  185. lambda elem: self._encode_list_element(elem, column_name=s.name)
  186. )
  187. s_values = self._get_ordinal_map(s.name)
  188. return s.map(s_values)
  189. df[self.output_columns] = df[self.columns].apply(column_ordinal_encoder)
  190. return df
  191. def _transform_arrow(self, table: pa.Table) -> pa.Table:
  192. """Transform using fast native PyArrow operations for scalar columns.
  193. List-type columns are preferably handled by _transform_pandas, which is selected
  194. via _determine_transform_to_use when a PyArrow schema is available. However,
  195. for pandas-backed datasets (PandasBlockSchema), we can't detect list columns
  196. until runtime, so we fall back to pandas here if list columns are found.
  197. """
  198. # Validate that columns don't contain null values (consistent with pandas path)
  199. _validate_arrow(table, *self.columns)
  200. # Check for list columns (runtime fallback for PandasBlockSchema datasets)
  201. for col_name in self.columns:
  202. col_type = table.schema.field(col_name).type
  203. if pa.types.is_list(col_type) or pa.types.is_large_list(col_type):
  204. # Fall back to pandas transform for list columns
  205. df = table.to_pandas()
  206. result_df = self._transform_pandas(df)
  207. return pa.Table.from_pandas(result_df, preserve_index=False)
  208. for input_col, output_col in zip(self.columns, self.output_columns):
  209. column = table.column(input_col)
  210. encoded_column = self._encode_column_vectorized(column, input_col)
  211. table = BlockAccessor.for_block(table).upsert_column(
  212. output_col, encoded_column
  213. )
  214. return table
  215. def _encode_column_vectorized(
  216. self, column: pa.ChunkedArray, input_col: str
  217. ) -> pa.Array:
  218. """Encode column using PyArrow's vectorized pc.index_in.
  219. Unseen categories are encoded as null in the output, which becomes NaN
  220. when converted to pandas. Null values should be validated before calling
  221. this method via _validate_arrow.
  222. """
  223. keys_array, values_array = self._get_arrow_arrays(input_col)
  224. if keys_array.type != column.type:
  225. keys_array = pc.cast(keys_array, column.type)
  226. # pc.index_in returns null for values not found in keys_array
  227. # (including null input values and unseen categories)
  228. indices = pc.index_in(column, keys_array)
  229. # pc.take preserves nulls from indices, so null inputs -> null outputs
  230. return pc.take(values_array, indices)
  231. @classmethod
  232. @DeveloperAPI
  233. def preferred_batch_format(cls) -> BatchFormat:
  234. return BatchFormat.ARROW
  235. def _get_serializable_fields(self) -> Dict[str, Any]:
  236. return {
  237. "columns": self.columns,
  238. "output_columns": self.output_columns,
  239. "encode_lists": self.encode_lists,
  240. "_fitted": getattr(self, "_fitted", None),
  241. }
  242. def _set_serializable_fields(self, fields: Dict[str, Any], version: int):
  243. # required fields
  244. self.columns = fields["columns"]
  245. self.output_columns = fields["output_columns"]
  246. self.encode_lists = fields["encode_lists"]
  247. # optional fields
  248. self._fitted = fields.get("_fitted")
  249. def __repr__(self):
  250. return (
  251. f"{self.__class__.__name__}(columns={self.columns!r}, "
  252. f"encode_lists={self.encode_lists!r}, "
  253. f"output_columns={self.output_columns!r})"
  254. )
  255. @PublicAPI(stability="alpha")
  256. @SerializablePreprocessor(version=1, identifier="io.ray.preprocessors.one_hot_encoder")
  257. class OneHotEncoder(SerializablePreprocessorBase):
  258. r"""`One-hot encode <https://en.wikipedia.org/wiki/One-hot#Machine_learning_and_statistics>`_
  259. categorical data.
  260. This preprocessor transforms each specified column into a one-hot encoded vector.
  261. Each element in the vector corresponds to a unique category in the column, with a
  262. value of 1 if the category matches and 0 otherwise.
  263. If a category is infrequent (based on ``max_categories``) or not present in the
  264. fitted dataset, it is encoded as all 0s.
  265. Columns must contain hashable objects or lists of hashable objects.
  266. .. note::
  267. Lists are treated as categories. If you want to encode individual list
  268. elements, use :class:`MultiHotEncoder`.
  269. Example:
  270. >>> import pandas as pd
  271. >>> import ray
  272. >>> from ray.data.preprocessors import OneHotEncoder
  273. >>>
  274. >>> df = pd.DataFrame({"color": ["red", "green", "red", "red", "blue", "green"]})
  275. >>> ds = ray.data.from_pandas(df) # doctest: +SKIP
  276. >>> encoder = OneHotEncoder(columns=["color"])
  277. >>> encoder.fit_transform(ds).to_pandas() # doctest: +SKIP
  278. color
  279. 0 [0, 0, 1]
  280. 1 [0, 1, 0]
  281. 2 [0, 0, 1]
  282. 3 [0, 0, 1]
  283. 4 [1, 0, 0]
  284. 5 [0, 1, 0]
  285. OneHotEncoder can also be used in append mode by providing the
  286. name of the output_columns that should hold the encoded values.
  287. >>> encoder = OneHotEncoder(columns=["color"], output_columns=["color_encoded"])
  288. >>> encoder.fit_transform(ds).to_pandas() # doctest: +SKIP
  289. color color_encoded
  290. 0 red [0, 0, 1]
  291. 1 green [0, 1, 0]
  292. 2 red [0, 0, 1]
  293. 3 red [0, 0, 1]
  294. 4 blue [1, 0, 0]
  295. 5 green [0, 1, 0]
  296. If you one-hot encode a value that isn't in the fitted dataset, then the
  297. value is encoded with zeros.
  298. >>> df = pd.DataFrame({"color": ["yellow"]})
  299. >>> batch = ray.data.from_pandas(df) # doctest: +SKIP
  300. >>> encoder.transform(batch).to_pandas() # doctest: +SKIP
  301. color color_encoded
  302. 0 yellow [0, 0, 0]
  303. Likewise, if you one-hot encode an infrequent value, then the value is encoded
  304. with zeros.
  305. >>> encoder = OneHotEncoder(columns=["color"], max_categories={"color": 2})
  306. >>> encoder.fit_transform(ds).to_pandas() # doctest: +SKIP
  307. color
  308. 0 [1, 0]
  309. 1 [0, 1]
  310. 2 [1, 0]
  311. 3 [1, 0]
  312. 4 [0, 0]
  313. 5 [0, 1]
  314. Args:
  315. columns: The columns to separately encode.
  316. max_categories: The maximum number of features to create for each column.
  317. If a value isn't specified for a column, then a feature is created
  318. for every category in that column.
  319. output_columns: The names of the transformed columns. If None, the transformed
  320. columns will be the same as the input columns. If not None, the length of
  321. ``output_columns`` must match the length of ``columns``, othwerwise an error
  322. will be raised.
  323. .. seealso::
  324. :class:`MultiHotEncoder`
  325. If you want to encode individual list elements, use
  326. :class:`MultiHotEncoder`.
  327. :class:`OrdinalEncoder`
  328. If your categories are ordered, you may want to use
  329. :class:`OrdinalEncoder`.
  330. """ # noqa: E501
  331. def __init__(
  332. self,
  333. columns: List[str],
  334. *,
  335. max_categories: Optional[Dict[str, int]] = None,
  336. output_columns: Optional[List[str]] = None,
  337. ):
  338. super().__init__()
  339. # TODO: add `drop` parameter.
  340. self.columns = columns
  341. self.max_categories = max_categories or {}
  342. self.output_columns = Preprocessor._derive_and_validate_output_columns(
  343. columns, output_columns
  344. )
  345. def _fit(self, dataset: "Dataset") -> Preprocessor:
  346. self.stat_computation_plan.add_callable_stat(
  347. stat_fn=lambda key_gen: compute_unique_value_indices(
  348. dataset=dataset,
  349. columns=self.columns,
  350. encode_lists=False,
  351. key_gen=key_gen,
  352. max_categories=self.max_categories,
  353. ),
  354. post_process_fn=unique_post_fn(),
  355. stat_key_fn=lambda col: f"unique({col})",
  356. post_key_fn=lambda col: f"unique_values({col})",
  357. columns=self.columns,
  358. )
  359. return self
  360. @classmethod
  361. @DeveloperAPI
  362. def preferred_batch_format(cls) -> BatchFormat:
  363. return BatchFormat.ARROW
  364. def safe_get(self, v: Any, stats: Dict[str, int]):
  365. if isinstance(v, (list, np.ndarray)):
  366. v = tuple(v)
  367. if isinstance(v, Hashable):
  368. return stats.get(v, -1)
  369. else:
  370. return -1 # Unhashable type treated as a missing category
  371. def _transform_pandas(self, df: pd.DataFrame):
  372. _validate_df(df, *self.columns)
  373. # Compute new one-hot encoded columns
  374. for column, output_column in zip(self.columns, self.output_columns):
  375. stats = self.stats_[f"unique_values({column})"]
  376. num_categories = len(stats)
  377. one_hot = np.zeros((len(df), num_categories), dtype=np.uint8)
  378. # Integer indices for each category in the column
  379. codes = df[column].apply(lambda v: self.safe_get(v, stats)).to_numpy()
  380. # Filter to only the rows that have a valid category
  381. valid_category_mask = codes != -1
  382. # Dimension should be (num_rows, ) - 1D boolean array
  383. non_zero_indices = np.nonzero(valid_category_mask)[0]
  384. # Mark the corresponding categories as 1
  385. one_hot[
  386. non_zero_indices,
  387. codes[valid_category_mask],
  388. ] = 1
  389. df[output_column] = one_hot.tolist()
  390. return df
  391. def _transform_arrow(self, table: pa.Table) -> pa.Table:
  392. """Transform using fast native PyArrow operations for scalar columns.
  393. List-type columns are preferably handled by _transform_pandas, which is selected
  394. via _determine_transform_to_use when a PyArrow schema is available. However,
  395. for pandas-backed datasets (PandasBlockSchema), we can't detect list columns
  396. until runtime, so we fall back to pandas here if list columns are found.
  397. """
  398. # Validate that columns don't contain null values (consistent with pandas path)
  399. _validate_arrow(table, *self.columns)
  400. # Check for list columns (runtime fallback for PandasBlockSchema datasets)
  401. for col_name in self.columns:
  402. col_type = table.schema.field(col_name).type
  403. if pa.types.is_list(col_type) or pa.types.is_large_list(col_type):
  404. # Fall back to pandas transform for list columns
  405. df = table.to_pandas()
  406. result_df = self._transform_pandas(df)
  407. return pa.Table.from_pandas(result_df, preserve_index=False)
  408. for input_col, output_col in zip(self.columns, self.output_columns):
  409. column = table.column(input_col)
  410. encoded_column = self._encode_column_one_hot(column, input_col)
  411. table = BlockAccessor.for_block(table).upsert_column(
  412. output_col, encoded_column
  413. )
  414. return table
  415. def _get_arrow_arrays(self, input_col: str) -> Tuple[pa.Array, pa.Array]:
  416. """Get Arrow arrays for keys and values."""
  417. return _get_unique_value_arrow_arrays(self.stats_, input_col)
  418. def _encode_column_one_hot(
  419. self, column: pa.ChunkedArray, input_col: str
  420. ) -> pa.FixedSizeListArray:
  421. """Encode a column to one-hot vectors using Arrow arrays.
  422. Unseen categories are encoded as all-zeros vectors, matching the pandas
  423. behavior. Null values should be validated before calling this method
  424. via _validate_arrow.
  425. """
  426. keys_array, _ = self._get_arrow_arrays(input_col)
  427. num_categories = len(keys_array)
  428. # Cast keys to match column type if needed
  429. if keys_array.type != column.type:
  430. keys_array = pc.cast(keys_array, column.type)
  431. # Use pc.index_in to find position of each value in keys_array
  432. # Returns null for null inputs and unseen categories (values not in keys_array)
  433. indices = pc.index_in(column, keys_array)
  434. # Fill nulls with -1 so they can be filtered out below (resulting in all-zeros)
  435. indices_filled = pc.fill_null(indices, -1)
  436. # Create one-hot encoded matrix using vectorized NumPy operations
  437. num_rows = len(column)
  438. indices_np = indices_filled.to_numpy()
  439. one_hot_matrix = np.zeros((num_rows, num_categories), dtype=np.uint8)
  440. # Find valid indices (not -1) and set 1s at the appropriate positions
  441. valid_mask = indices_np != -1
  442. valid_indices = np.nonzero(valid_mask)[0]
  443. if len(valid_indices) > 0:
  444. one_hot_matrix[valid_indices, indices_np[valid_mask]] = 1
  445. # Convert to Arrow FixedSizeListArray for efficient storage
  446. return pa.FixedSizeListArray.from_arrays(one_hot_matrix.ravel(), num_categories)
  447. def _get_serializable_fields(self) -> Dict[str, Any]:
  448. return {
  449. "columns": self.columns,
  450. "output_columns": self.output_columns,
  451. "max_categories": self.max_categories,
  452. "_fitted": getattr(self, "_fitted", None),
  453. }
  454. def _set_serializable_fields(self, fields: Dict[str, Any], version: int):
  455. # required fields
  456. self.columns = fields["columns"]
  457. self.output_columns = fields["output_columns"]
  458. self.max_categories = fields["max_categories"]
  459. # optional fields
  460. self._fitted = fields.get("_fitted")
  461. def __repr__(self):
  462. return (
  463. f"{self.__class__.__name__}(columns={self.columns!r}, "
  464. f"max_categories={self.max_categories!r}, "
  465. f"output_columns={self.output_columns!r})"
  466. )
  467. @PublicAPI(stability="alpha")
  468. @SerializablePreprocessor(
  469. version=1, identifier="io.ray.preprocessors.multi_hot_encoder"
  470. )
  471. class MultiHotEncoder(SerializablePreprocessorBase):
  472. r"""Multi-hot encode categorical data.
  473. This preprocessor replaces each list of categories with an :math:`m`-length binary
  474. list, where :math:`m` is the number of unique categories in the column or the value
  475. specified in ``max_categories``. The :math:`i\\text{-th}` element of the binary list
  476. is :math:`1` if category :math:`i` is in the input list and :math:`0` otherwise.
  477. Columns must contain hashable objects or lists of hashable objects.
  478. Also, you can't have both types in the same column.
  479. .. note::
  480. The logic is similar to scikit-learn's [MultiLabelBinarizer][1]
  481. Examples:
  482. >>> import pandas as pd
  483. >>> import ray
  484. >>> from ray.data.preprocessors import MultiHotEncoder
  485. >>>
  486. >>> df = pd.DataFrame({
  487. ... "name": ["Shaolin Soccer", "Moana", "The Smartest Guys in the Room"],
  488. ... "genre": [
  489. ... ["comedy", "action", "sports"],
  490. ... ["animation", "comedy", "action"],
  491. ... ["documentary"],
  492. ... ],
  493. ... })
  494. >>> ds = ray.data.from_pandas(df) # doctest: +SKIP
  495. >>>
  496. >>> encoder = MultiHotEncoder(columns=["genre"])
  497. >>> encoder.fit_transform(ds).to_pandas() # doctest: +SKIP
  498. name genre
  499. 0 Shaolin Soccer [1, 0, 1, 0, 1]
  500. 1 Moana [1, 1, 1, 0, 0]
  501. 2 The Smartest Guys in the Room [0, 0, 0, 1, 0]
  502. :class:`MultiHotEncoder` can also be used in append mode by providing the
  503. name of the output_columns that should hold the encoded values.
  504. >>> encoder = MultiHotEncoder(columns=["genre"], output_columns=["genre_encoded"])
  505. >>> encoder.fit_transform(ds).to_pandas() # doctest: +SKIP
  506. name genre genre_encoded
  507. 0 Shaolin Soccer [comedy, action, sports] [1, 0, 1, 0, 1]
  508. 1 Moana [animation, comedy, action] [1, 1, 1, 0, 0]
  509. 2 The Smartest Guys in the Room [documentary] [0, 0, 0, 1, 0]
  510. If you specify ``max_categories``, then :class:`MultiHotEncoder`
  511. creates features for only the most frequent categories.
  512. >>> encoder = MultiHotEncoder(columns=["genre"], max_categories={"genre": 3})
  513. >>> encoder.fit_transform(ds).to_pandas() # doctest: +SKIP
  514. name genre
  515. 0 Shaolin Soccer [1, 1, 1]
  516. 1 Moana [1, 1, 0]
  517. 2 The Smartest Guys in the Room [0, 0, 0]
  518. >>> encoder.stats_ # doctest: +SKIP
  519. OrderedDict([('unique_values(genre)', {'comedy': 0, 'action': 1, 'sports': 2})])
  520. Args:
  521. columns: The columns to separately encode.
  522. max_categories: The maximum number of features to create for each column.
  523. If a value isn't specified for a column, then a feature is created
  524. for every unique category in that column.
  525. output_columns: The names of the transformed columns. If None, the transformed
  526. columns will be the same as the input columns. If not None, the length of
  527. ``output_columns`` must match the length of ``columns``, othwerwise an error
  528. will be raised.
  529. .. seealso::
  530. :class:`OneHotEncoder`
  531. If you're encoding individual categories instead of lists of
  532. categories, use :class:`OneHotEncoder`.
  533. :class:`OrdinalEncoder`
  534. If your categories are ordered, you may want to use
  535. :class:`OrdinalEncoder`.
  536. [1]: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MultiLabelBinarizer.html
  537. """
  538. def __init__(
  539. self,
  540. columns: List[str],
  541. *,
  542. max_categories: Optional[Dict[str, int]] = None,
  543. output_columns: Optional[List[str]] = None,
  544. ):
  545. super().__init__()
  546. # TODO: add `drop` parameter.
  547. self.columns = columns
  548. self.max_categories = max_categories or {}
  549. self.output_columns = Preprocessor._derive_and_validate_output_columns(
  550. columns, output_columns
  551. )
  552. def _fit(self, dataset: "Dataset") -> Preprocessor:
  553. self.stat_computation_plan.add_callable_stat(
  554. stat_fn=lambda key_gen: compute_unique_value_indices(
  555. dataset=dataset,
  556. columns=self.columns,
  557. encode_lists=True,
  558. key_gen=key_gen,
  559. max_categories=self.max_categories,
  560. ),
  561. post_process_fn=unique_post_fn(),
  562. stat_key_fn=lambda col: f"unique({col})",
  563. post_key_fn=lambda col: f"unique_values({col})",
  564. columns=self.columns,
  565. )
  566. return self
  567. def _transform_pandas(self, df: pd.DataFrame):
  568. _validate_df(df, *self.columns)
  569. def encode_list(element: list, *, name: str):
  570. if isinstance(element, np.ndarray):
  571. element = element.tolist()
  572. elif not isinstance(element, list):
  573. element = [element]
  574. stats = self.stats_[f"unique_values({name})"]
  575. counter = Counter(element)
  576. return [counter.get(x, 0) for x in stats]
  577. for column, output_column in zip(self.columns, self.output_columns):
  578. df[output_column] = df[column].map(partial(encode_list, name=column))
  579. return df
  580. def _get_serializable_fields(self) -> Dict[str, Any]:
  581. return {
  582. "columns": self.columns,
  583. "output_columns": self.output_columns,
  584. "max_categories": self.max_categories,
  585. "_fitted": getattr(self, "_fitted", None),
  586. }
  587. def _set_serializable_fields(self, fields: Dict[str, Any], version: int):
  588. # required fields
  589. self.columns = fields["columns"]
  590. self.output_columns = fields["output_columns"]
  591. self.max_categories = fields["max_categories"]
  592. # optional fields
  593. self._fitted = fields.get("_fitted")
  594. def __repr__(self):
  595. return (
  596. f"{self.__class__.__name__}(columns={self.columns!r}, "
  597. f"max_categories={self.max_categories!r}, "
  598. f"output_columns={self.output_columns})"
  599. )
  600. @PublicAPI(stability="alpha")
  601. @SerializablePreprocessor(version=1, identifier="io.ray.preprocessors.label_encoder")
  602. class LabelEncoder(SerializablePreprocessorBase):
  603. r"""Encode labels as integer targets.
  604. :class:`LabelEncoder` encodes labels as integer targets that range from
  605. :math:`0` to :math:`n - 1`, where :math:`n` is the number of unique labels.
  606. If you transform a label that isn't in the fitted dataset, then the label is encoded
  607. as ``float("nan")``.
  608. Examples:
  609. >>> import pandas as pd
  610. >>> import ray
  611. >>> df = pd.DataFrame({
  612. ... "sepal_width": [5.1, 7, 4.9, 6.2],
  613. ... "sepal_height": [3.5, 3.2, 3, 3.4],
  614. ... "species": ["setosa", "versicolor", "setosa", "virginica"]
  615. ... })
  616. >>> ds = ray.data.from_pandas(df) # doctest: +SKIP
  617. >>>
  618. >>> from ray.data.preprocessors import LabelEncoder
  619. >>> encoder = LabelEncoder(label_column="species")
  620. >>> encoder.fit_transform(ds).to_pandas() # doctest: +SKIP
  621. sepal_width sepal_height species
  622. 0 5.1 3.5 0
  623. 1 7.0 3.2 1
  624. 2 4.9 3.0 0
  625. 3 6.2 3.4 2
  626. You can also provide the name of the output column that should hold the encoded
  627. labels if you want to use :class:`LabelEncoder` in append mode.
  628. >>> encoder = LabelEncoder(label_column="species", output_column="species_encoded")
  629. >>> encoder.fit_transform(ds).to_pandas() # doctest: +SKIP
  630. sepal_width sepal_height species species_encoded
  631. 0 5.1 3.5 setosa 0
  632. 1 7.0 3.2 versicolor 1
  633. 2 4.9 3.0 setosa 0
  634. 3 6.2 3.4 virginica 2
  635. If you transform a label not present in the original dataset, then the new
  636. label is encoded as ``float("nan")``.
  637. >>> df = pd.DataFrame({
  638. ... "sepal_width": [4.2],
  639. ... "sepal_height": [2.7],
  640. ... "species": ["bracteata"]
  641. ... })
  642. >>> ds = ray.data.from_pandas(df) # doctest: +SKIP
  643. >>> encoder.transform(ds).to_pandas() # doctest: +SKIP
  644. sepal_width sepal_height species
  645. 0 4.2 2.7 NaN
  646. Args:
  647. label_column: A column containing labels that you want to encode.
  648. output_column: The name of the column that will contain the encoded
  649. labels. If None, the output column will have the same name as the
  650. input column.
  651. .. seealso::
  652. :class:`OrdinalEncoder`
  653. If you're encoding ordered features, use :class:`OrdinalEncoder` instead of
  654. :class:`LabelEncoder`.
  655. """
  656. def __init__(self, label_column: str, *, output_column: Optional[str] = None):
  657. super().__init__()
  658. self.label_column = label_column
  659. self.output_column = output_column or label_column
  660. def _fit(self, dataset: "Dataset") -> Preprocessor:
  661. self.stat_computation_plan.add_callable_stat(
  662. stat_fn=lambda key_gen: compute_unique_value_indices(
  663. dataset=dataset,
  664. columns=[self.label_column],
  665. key_gen=key_gen,
  666. ),
  667. post_process_fn=unique_post_fn(),
  668. stat_key_fn=lambda col: f"unique({col})",
  669. post_key_fn=lambda col: f"unique_values({col})",
  670. columns=[self.label_column],
  671. )
  672. return self
  673. def _transform_pandas(self, df: pd.DataFrame):
  674. _validate_df(df, self.label_column)
  675. def column_label_encoder(s: pd.Series):
  676. s_values = self.stats_[f"unique_values({s.name})"]
  677. return s.map(s_values)
  678. df[self.output_column] = df[self.label_column].transform(column_label_encoder)
  679. return df
  680. def inverse_transform(self, ds: "Dataset") -> "Dataset":
  681. """Inverse transform the given dataset.
  682. Args:
  683. ds: Input Dataset that has been fitted and/or transformed.
  684. Returns:
  685. ray.data.Dataset: The inverse transformed Dataset.
  686. Raises:
  687. PreprocessorNotFittedException: if ``fit`` is not called yet.
  688. """
  689. fit_status = self.fit_status()
  690. if fit_status in (
  691. Preprocessor.FitStatus.PARTIALLY_FITTED,
  692. Preprocessor.FitStatus.NOT_FITTED,
  693. ):
  694. raise PreprocessorNotFittedException(
  695. "`fit` must be called before `inverse_transform`, "
  696. )
  697. kwargs = self._get_transform_config()
  698. return ds.map_batches(
  699. self._inverse_transform_pandas, batch_format=BatchFormat.PANDAS, **kwargs
  700. )
  701. def _inverse_transform_pandas(self, df: pd.DataFrame):
  702. def column_label_decoder(s: pd.Series):
  703. inverse_values = {
  704. value: key
  705. for key, value in self.stats_[
  706. f"unique_values({self.label_column})"
  707. ].items()
  708. }
  709. return s.map(inverse_values)
  710. df[self.label_column] = df[self.output_column].transform(column_label_decoder)
  711. return df
  712. def get_input_columns(self) -> List[str]:
  713. return [self.label_column]
  714. def get_output_columns(self) -> List[str]:
  715. return [self.output_column]
  716. def _get_serializable_fields(self) -> Dict[str, Any]:
  717. return {
  718. "label_column": self.label_column,
  719. "output_column": self.output_column,
  720. "_fitted": getattr(self, "_fitted", None),
  721. }
  722. def _set_serializable_fields(self, fields: Dict[str, Any], version: int):
  723. # required fields
  724. self.label_column = fields["label_column"]
  725. self.output_column = fields["output_column"]
  726. # optional fields
  727. self._fitted = fields.get("_fitted")
  728. def __repr__(self):
  729. return f"{self.__class__.__name__}(label_column={self.label_column!r}, output_column={self.output_column!r})"
  730. @PublicAPI(stability="alpha")
  731. @SerializablePreprocessor(version=1, identifier="io.ray.preprocessors.categorizer")
  732. class Categorizer(SerializablePreprocessorBase):
  733. r"""Convert columns to ``pd.CategoricalDtype``.
  734. Use this preprocessor with frameworks that have built-in support for
  735. ``pd.CategoricalDtype`` like LightGBM.
  736. .. warning::
  737. If you don't specify ``dtypes``, fit this preprocessor before splitting
  738. your dataset into train and test splits. This ensures categories are
  739. consistent across splits.
  740. Examples:
  741. >>> import pandas as pd
  742. >>> import ray
  743. >>> from ray.data.preprocessors import Categorizer
  744. >>>
  745. >>> df = pd.DataFrame(
  746. ... {
  747. ... "sex": ["male", "female", "male", "female"],
  748. ... "level": ["L4", "L5", "L3", "L4"],
  749. ... })
  750. >>> ds = ray.data.from_pandas(df) # doctest: +SKIP
  751. >>> categorizer = Categorizer(columns=["sex", "level"])
  752. >>> categorizer.fit_transform(ds).schema().types # doctest: +SKIP
  753. [CategoricalDtype(categories=['female', 'male'], ordered=False), CategoricalDtype(categories=['L3', 'L4', 'L5'], ordered=False)]
  754. :class:`Categorizer` can also be used in append mode by providing the
  755. name of the output_columns that should hold the categorized values.
  756. >>> categorizer = Categorizer(columns=["sex", "level"], output_columns=["sex_cat", "level_cat"])
  757. >>> categorizer.fit_transform(ds).to_pandas() # doctest: +SKIP
  758. sex level sex_cat level_cat
  759. 0 male L4 male L4
  760. 1 female L5 female L5
  761. 2 male L3 male L3
  762. 3 female L4 female L4
  763. If you know the categories in advance, you can specify the categories with the
  764. ``dtypes`` parameter.
  765. >>> categorizer = Categorizer(
  766. ... columns=["sex", "level"],
  767. ... dtypes={"level": pd.CategoricalDtype(["L3", "L4", "L5", "L6"], ordered=True)},
  768. ... )
  769. >>> categorizer.fit_transform(ds).schema().types # doctest: +SKIP
  770. [CategoricalDtype(categories=['female', 'male'], ordered=False), CategoricalDtype(categories=['L3', 'L4', 'L5', 'L6'], ordered=True)]
  771. Args:
  772. columns: The columns to convert to ``pd.CategoricalDtype``.
  773. dtypes: An optional dictionary that maps columns to ``pd.CategoricalDtype``
  774. objects. If you don't include a column in ``dtypes``, the categories
  775. are inferred.
  776. output_columns: The names of the transformed columns. If None, the transformed
  777. columns will be the same as the input columns. If not None, the length of
  778. ``output_columns`` must match the length of ``columns``, othwerwise an error
  779. will be raised.
  780. """ # noqa: E501
  781. def __init__(
  782. self,
  783. columns: List[str],
  784. dtypes: Optional[Dict[str, pd.CategoricalDtype]] = None,
  785. output_columns: Optional[List[str]] = None,
  786. ):
  787. super().__init__()
  788. if not dtypes:
  789. dtypes = {}
  790. self.columns = columns
  791. self.dtypes = dtypes
  792. self.output_columns = Preprocessor._derive_and_validate_output_columns(
  793. columns, output_columns
  794. )
  795. def _fit(self, dataset: "Dataset") -> Preprocessor:
  796. columns_to_get = [
  797. column for column in self.columns if column not in self.dtypes
  798. ]
  799. self.stats_ |= self.dtypes
  800. if not columns_to_get:
  801. return self
  802. def callback(unique_indices: Dict[str, Dict]) -> pd.CategoricalDtype:
  803. return pd.CategoricalDtype(unique_indices.keys())
  804. self.stat_computation_plan.add_callable_stat(
  805. stat_fn=lambda key_gen: compute_unique_value_indices(
  806. dataset=dataset,
  807. columns=columns_to_get,
  808. key_gen=key_gen,
  809. ),
  810. post_process_fn=make_post_processor(
  811. base_fn=unique_post_fn(drop_na_values=True),
  812. callbacks=[callback],
  813. ),
  814. stat_key_fn=lambda col: f"unique({col})",
  815. post_key_fn=lambda col: col,
  816. columns=columns_to_get,
  817. )
  818. return self
  819. def _transform_pandas(self, df: pd.DataFrame):
  820. df[self.output_columns] = df[self.columns].astype(self.stats_)
  821. return df
  822. def _get_serializable_fields(self) -> Dict[str, Any]:
  823. return {
  824. "columns": self.columns,
  825. "output_columns": self.output_columns,
  826. "_fitted": getattr(self, "_fitted", None),
  827. "dtypes": {
  828. col: {"categories": list(dtype.categories), "ordered": dtype.ordered}
  829. for col, dtype in self.dtypes.items()
  830. }
  831. if hasattr(self, "dtypes") and self.dtypes
  832. else None,
  833. }
  834. def _set_serializable_fields(self, fields: Dict[str, Any], version: int):
  835. # required fields
  836. # Handle dtypes field specially
  837. self.dtypes = (
  838. {
  839. col: pd.CategoricalDtype(
  840. categories=dtype_data["categories"], ordered=dtype_data["ordered"]
  841. )
  842. for col, dtype_data in fields["dtypes"].items()
  843. }
  844. if fields.get("dtypes")
  845. else {}
  846. )
  847. self.columns = fields["columns"]
  848. self.output_columns = fields["output_columns"]
  849. # optional fields
  850. self._fitted = fields.get("_fitted")
  851. def __repr__(self):
  852. return (
  853. f"{self.__class__.__name__}(columns={self.columns!r}, "
  854. f"dtypes={self.dtypes!r}, output_columns={self.output_columns!r})"
  855. )
  856. def compute_unique_value_indices(
  857. *,
  858. dataset: "Dataset",
  859. columns: List[str],
  860. key_gen: Callable,
  861. encode_lists: bool = True,
  862. max_categories: Optional[Dict[str, int]] = None,
  863. ):
  864. if max_categories is None:
  865. max_categories = {}
  866. columns_set = set(columns)
  867. for column in max_categories:
  868. if column not in columns_set:
  869. raise ValueError(
  870. f"You set `max_categories` for {column}, which is not present in "
  871. f"{columns}."
  872. )
  873. def get_pd_value_counts_per_column(col: pd.Series) -> Dict:
  874. # special handling for lists
  875. if _is_series_composed_of_lists(col):
  876. if encode_lists:
  877. counter = Counter()
  878. def update_counter(element):
  879. counter.update(element)
  880. return element
  881. col.map(update_counter)
  882. return counter
  883. else:
  884. # convert to tuples to make lists hashable
  885. col = col.map(lambda x: tuple(x))
  886. return Counter(col.value_counts(dropna=False).to_dict())
  887. def get_pd_value_counts(df: pd.DataFrame) -> Dict[str, List[Dict]]:
  888. df_columns = df.columns.tolist()
  889. result = {}
  890. for col in columns:
  891. if col in df_columns:
  892. result[col] = [get_pd_value_counts_per_column(df[col])]
  893. else:
  894. raise ValueError(
  895. f"Column '{col}' does not exist in DataFrame, which has columns: {df_columns}" # noqa: E501
  896. )
  897. return result
  898. value_counts_ds = dataset.map_batches(get_pd_value_counts, batch_format="pandas")
  899. unique_values_by_col: Dict[str, Set] = {key_gen(col): set() for col in columns}
  900. for batch in value_counts_ds.iter_batches(batch_size=None):
  901. for col, counters in batch.items():
  902. for counter in counters:
  903. counter: Dict[Any, int] = {
  904. k: v for k, v in counter.items() if v is not None
  905. }
  906. if col in max_categories:
  907. counter: Dict[Any, int] = dict(
  908. Counter(counter).most_common(max_categories[col])
  909. )
  910. # add only column values since frequencies are needed beyond this point
  911. unique_values_by_col[key_gen(col)].update(counter.keys())
  912. return unique_values_by_col
  913. # FIXME: the arrow format path is broken: https://anyscale1.atlassian.net/browse/DATA-1788
  914. def unique_post_fn(
  915. drop_na_values: bool = False, batch_format: BatchFormat = None
  916. ) -> Callable:
  917. """
  918. Returns a post-processing function that generates an encoding map by
  919. sorting the unique values produced during aggregation or stats computation.
  920. Args:
  921. drop_na_values: If True, NA/null values will be silently dropped from the
  922. encoding map. If False, raises an error if any NA/null values are present.
  923. batch_format: Determines the output format of the encoding map.
  924. - If BatchFormat.ARROW: Returns Arrow format (tuple of arrays) for scalar
  925. types, or dict format for list types that PyArrow can't sort.
  926. - Otherwise: Returns pandas dict format {value: index}.
  927. Returns:
  928. A callable that takes unique values and returns an encoding map.
  929. The map format depends on batch_format and input types:
  930. - Dict format: {value: int} - used for pandas path or list-type data
  931. - Arrow format: (keys_array, values_array) - used for Arrow path with scalar data
  932. """
  933. def gen_value_index(values: List) -> Dict[Any, int]:
  934. """
  935. Generate an encoding map from a list of unique values using Python sorting.
  936. Args:
  937. values: List of unique values to encode (can include lists/tuples).
  938. Returns:
  939. Dict mapping each value to a unique integer index.
  940. List values are converted to tuples for hashability.
  941. Raises:
  942. ValueError: If null values are present and drop_na_values is False.
  943. """
  944. # NOTE: We special-case null here since it prevents provided
  945. # values sequence from being sortable
  946. if any(is_null(v) for v in values) and not drop_na_values:
  947. raise ValueError(
  948. "Unable to fit column because it contains null"
  949. " values. Consider imputing missing values first."
  950. )
  951. non_null_values = [v for v in values if not is_null(v)]
  952. return {
  953. (v if not isinstance(v, list) else tuple(v)): i
  954. # NOTE: Sorting applied to produce stable encoding
  955. for i, v in enumerate(sorted(non_null_values))
  956. }
  957. def gen_value_index_arrow_from_arrow(
  958. values: Union["pa.ListScalar", "pa.Array"],
  959. ) -> Union[Tuple["pa.Array", "pa.Array"], Dict[Any, int]]:
  960. """Generate an encoding map from unique values using Arrow-native operations.
  961. Args:
  962. values: The aggregation result as a pa.ListScalar (list of unique values)
  963. or a pa.Array of values directly.
  964. Returns:
  965. For scalar types that PyArrow can sort natively, returns a tuple of
  966. (sorted_keys, indices) as pa.Array. For list types that require fallback,
  967. returns a dict mapping {value: index}.
  968. Note:
  969. PyArrow's sort_indices doesn't support list types, so we fall back to
  970. dict format for columns containing lists. The _transform_arrow method
  971. handles this by detecting dict-format stats and converting as needed.
  972. """
  973. # Handle ListScalar from aggregation result
  974. if isinstance(values, pa.ListScalar):
  975. values = values.values
  976. # Check if values contain list types - PyArrow can't sort these
  977. # Fall back to pandas dict format for list types
  978. if pa.types.is_list(values.type) or pa.types.is_large_list(values.type):
  979. return gen_value_index(values.to_pylist())
  980. # Drop nulls if requested
  981. if drop_na_values:
  982. values = pc.drop_null(values)
  983. else:
  984. if pc.any(pc.is_null(values)).as_py():
  985. raise ValueError(
  986. "Unable to fit column because it contains null"
  987. " values. Consider imputing missing values first."
  988. )
  989. # Sort the values
  990. sorted_indices = pc.sort_indices(values)
  991. sorted_values = pc.take(values, sorted_indices)
  992. # Create the index array
  993. values_array = pa.array(range(len(sorted_values)), type=pa.int64())
  994. return (sorted_values, values_array)
  995. return (
  996. gen_value_index_arrow_from_arrow
  997. if batch_format == BatchFormat.ARROW
  998. else gen_value_index
  999. )
  1000. def _validate_df(df: pd.DataFrame, *columns: str) -> None:
  1001. null_columns = [column for column in columns if df[column].isnull().values.any()]
  1002. if null_columns:
  1003. raise ValueError(
  1004. f"Unable to transform columns {null_columns} because they contain "
  1005. f"null values. Consider imputing missing values first."
  1006. )
  1007. def _validate_arrow(table: pa.Table, *columns: str) -> None:
  1008. """Validate that specified columns in an Arrow table do not contain null values.
  1009. Args:
  1010. table: The Arrow table to validate.
  1011. *columns: Column names to check for null values.
  1012. Raises:
  1013. ValueError: If any of the specified columns contain null values.
  1014. """
  1015. null_columns = [
  1016. column for column in columns if pc.any(pc.is_null(table.column(column))).as_py()
  1017. ]
  1018. if null_columns:
  1019. raise ValueError(
  1020. f"Unable to transform columns {null_columns} because they contain "
  1021. f"null values. Consider imputing missing values first."
  1022. )
  1023. def _is_series_composed_of_lists(series: pd.Series) -> bool:
  1024. # we assume that all elements are a list here
  1025. first_not_none_element = next(
  1026. (element for element in series if element is not None), None
  1027. )
  1028. return pandas.api.types.is_object_dtype(series.dtype) and isinstance(
  1029. first_not_none_element, (list, np.ndarray)
  1030. )