data_batch_conversion.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. import warnings
  2. from enum import Enum
  3. from typing import TYPE_CHECKING, Dict, List, Union
  4. import numpy as np
  5. from ray.air.data_batch_type import DataBatchType
  6. from ray.data.constants import TENSOR_COLUMN_NAME
  7. from ray.util.annotations import DeveloperAPI
  8. if TYPE_CHECKING:
  9. import pandas as pd
  10. # TODO: Consolidate data conversion edges for arrow bug workaround.
  11. try:
  12. import pyarrow
  13. except ImportError:
  14. pyarrow = None
  15. # Lazy import to avoid ray init failures without pandas installed and allow
  16. # dataset to import modules in this file.
  17. _pandas = None
  18. def _lazy_import_pandas():
  19. global _pandas
  20. if _pandas is None:
  21. import pandas
  22. _pandas = pandas
  23. return _pandas
  24. @DeveloperAPI
  25. class BatchFormat(str, Enum):
  26. PANDAS = "pandas"
  27. # TODO: Remove once Arrow is deprecated as user facing batch format
  28. ARROW = "arrow"
  29. NUMPY = "numpy" # Either a single numpy array or a Dict of numpy arrays.
  30. def _convert_batch_type_to_pandas(
  31. data: DataBatchType,
  32. cast_tensor_columns: bool = False,
  33. ) -> "pd.DataFrame":
  34. """Convert the provided data to a Pandas DataFrame.
  35. Args:
  36. data: Data of type DataBatchType
  37. cast_tensor_columns: Whether tensor columns should be cast to NumPy ndarrays.
  38. Returns:
  39. A pandas Dataframe representation of the input data.
  40. """
  41. pd = _lazy_import_pandas()
  42. if isinstance(data, np.ndarray):
  43. data = pd.DataFrame({TENSOR_COLUMN_NAME: _ndarray_to_column(data)})
  44. elif isinstance(data, dict):
  45. tensor_dict = {}
  46. for col_name, col in data.items():
  47. if not isinstance(col, np.ndarray):
  48. raise ValueError(
  49. "All values in the provided dict must be of type "
  50. f"np.ndarray. Found type {type(col)} for key {col_name} "
  51. f"instead."
  52. )
  53. tensor_dict[col_name] = _ndarray_to_column(col)
  54. data = pd.DataFrame(tensor_dict)
  55. elif pyarrow is not None and isinstance(data, pyarrow.Table):
  56. data = data.to_pandas()
  57. elif not isinstance(data, pd.DataFrame):
  58. raise ValueError(
  59. f"Received data of type: {type(data)}, but expected it to be one "
  60. f"of {DataBatchType}"
  61. )
  62. if cast_tensor_columns:
  63. data = _cast_tensor_columns_to_ndarrays(data)
  64. return data
  65. def _convert_pandas_to_batch_type(
  66. data: "pd.DataFrame",
  67. type: BatchFormat,
  68. cast_tensor_columns: bool = False,
  69. ) -> DataBatchType:
  70. """Convert the provided Pandas dataframe to the provided ``type``.
  71. Args:
  72. data: A Pandas DataFrame
  73. type: The specific ``BatchFormat`` to convert to.
  74. cast_tensor_columns: Whether tensor columns should be cast to our tensor
  75. extension type.
  76. Returns:
  77. The input data represented with the provided type.
  78. """
  79. if cast_tensor_columns:
  80. data = _cast_ndarray_columns_to_tensor_extension(data)
  81. if type == BatchFormat.PANDAS:
  82. return data
  83. elif type == BatchFormat.NUMPY:
  84. if len(data.columns) == 1:
  85. # If just a single column, return as a single numpy array.
  86. return data.iloc[:, 0].to_numpy()
  87. else:
  88. # Else return as a dict of numpy arrays.
  89. output_dict = {}
  90. for column in data:
  91. output_dict[column] = data[column].to_numpy()
  92. return output_dict
  93. elif type == BatchFormat.ARROW:
  94. if not pyarrow:
  95. raise ValueError(
  96. "Attempted to convert data to Pyarrow Table but Pyarrow "
  97. "is not installed. Please do `pip install pyarrow` to "
  98. "install Pyarrow."
  99. )
  100. return pyarrow.Table.from_pandas(data)
  101. else:
  102. raise ValueError(
  103. f"Received type {type}, but expected it to be one of {DataBatchType}"
  104. )
  105. def _convert_batch_type_to_numpy(
  106. data: DataBatchType,
  107. ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
  108. """Convert the provided data to a NumPy ndarray or dict of ndarrays.
  109. Args:
  110. data: Data of type DataBatchType
  111. Returns:
  112. A numpy representation of the input data.
  113. """
  114. pd = _lazy_import_pandas()
  115. if isinstance(data, np.ndarray):
  116. return data
  117. elif isinstance(data, dict):
  118. for col_name, col in data.items():
  119. if not isinstance(col, np.ndarray):
  120. raise ValueError(
  121. "All values in the provided dict must be of type "
  122. f"np.ndarray. Found type {type(col)} for key {col_name} "
  123. f"instead."
  124. )
  125. return data
  126. elif pyarrow is not None and isinstance(data, pyarrow.Table):
  127. from ray.data._internal.arrow_ops import transform_pyarrow
  128. from ray.data._internal.tensor_extensions.arrow import (
  129. get_arrow_extension_fixed_shape_tensor_types,
  130. )
  131. column_values_ndarrays = []
  132. for col in data.columns:
  133. # Combine columnar values arrays to make these contiguous
  134. # (making them compatible with numpy format)
  135. combined_array = transform_pyarrow.combine_chunked_array(col)
  136. column_values_ndarrays.append(
  137. transform_pyarrow.to_numpy(combined_array, zero_copy_only=False)
  138. )
  139. arrow_fixed_shape_tensor_types = get_arrow_extension_fixed_shape_tensor_types()
  140. # NOTE: This branch is here for backwards-compatibility
  141. if data.column_names == [TENSOR_COLUMN_NAME] and (
  142. isinstance(data.schema.types[0], arrow_fixed_shape_tensor_types)
  143. ):
  144. return column_values_ndarrays[0]
  145. return dict(zip(data.column_names, column_values_ndarrays))
  146. elif isinstance(data, pd.DataFrame):
  147. return _convert_pandas_to_batch_type(data, BatchFormat.NUMPY)
  148. else:
  149. raise ValueError(
  150. f"Received data of type: {type(data)}, but expected it to be one "
  151. f"of {DataBatchType}"
  152. )
  153. def _ndarray_to_column(arr: np.ndarray) -> Union["pd.Series", List[np.ndarray]]:
  154. """Convert a NumPy ndarray into an appropriate column format for insertion into a
  155. pandas DataFrame.
  156. If conversion to a pandas Series fails (e.g. if the ndarray is multi-dimensional),
  157. fall back to a list of NumPy ndarrays.
  158. """
  159. pd = _lazy_import_pandas()
  160. try:
  161. # Try to convert to Series, falling back to a list conversion if this fails
  162. # (e.g. if the ndarray is multi-dimensional).
  163. return pd.Series(arr)
  164. except ValueError:
  165. return list(arr)
  166. def _unwrap_ndarray_object_type_if_needed(arr: np.ndarray) -> np.ndarray:
  167. """Unwrap an object-dtyped NumPy ndarray containing ndarray pointers into a single
  168. contiguous ndarray, if needed/possible.
  169. """
  170. if arr.dtype.type is np.object_:
  171. try:
  172. # Try to convert the NumPy ndarray to a non-object dtype.
  173. arr = np.array([np.asarray(v) for v in arr])
  174. except Exception:
  175. # This may fail if the subndarrays are of heterogeneous shape
  176. pass
  177. return arr
  178. def _cast_ndarray_columns_to_tensor_extension(df: "pd.DataFrame") -> "pd.DataFrame":
  179. """
  180. Cast all NumPy ndarray columns in df to our tensor extension type, TensorArray.
  181. """
  182. pd = _lazy_import_pandas()
  183. try:
  184. SettingWithCopyWarning = pd.core.common.SettingWithCopyWarning
  185. except AttributeError:
  186. # SettingWithCopyWarning was moved to pd.errors in Pandas 1.5.0.
  187. SettingWithCopyWarning = pd.errors.SettingWithCopyWarning
  188. from ray.data._internal.tensor_extensions.pandas import (
  189. TensorArray,
  190. column_needs_tensor_extension,
  191. )
  192. # Try to convert any ndarray columns to TensorArray columns.
  193. # TODO(Clark): Once Pandas supports registering extension types for type
  194. # inference on construction, implement as much for NumPy ndarrays and remove
  195. # this. See https://github.com/pandas-dev/pandas/issues/41848
  196. # TODO(Clark): Optimize this with propagated DataFrame metadata containing a list of
  197. # column names containing tensor columns, to make this an O(# of tensor columns)
  198. # check rather than the current O(# of columns) check.
  199. for col_name, col in df.items():
  200. if column_needs_tensor_extension(col):
  201. try:
  202. # Suppress Pandas warnings:
  203. # https://github.com/ray-project/ray/issues/29270
  204. # We actually want in-place operations so we surpress this warning.
  205. # https://stackoverflow.com/a/74193599
  206. with warnings.catch_warnings():
  207. warnings.simplefilter("ignore", category=FutureWarning)
  208. warnings.simplefilter("ignore", category=SettingWithCopyWarning)
  209. df[col_name] = TensorArray(col)
  210. except Exception as e:
  211. raise ValueError(
  212. f"Tried to cast column {col_name} to the TensorArray tensor "
  213. "extension type but the conversion failed. To disable "
  214. "automatic casting to this tensor extension, set "
  215. "ctx = DataContext.get_current(); "
  216. "ctx.enable_tensor_extension_casting = False."
  217. ) from e
  218. return df
  219. def _cast_tensor_columns_to_ndarrays(df: "pd.DataFrame") -> "pd.DataFrame":
  220. """Cast all tensor extension columns in df to NumPy ndarrays."""
  221. pd = _lazy_import_pandas()
  222. try:
  223. SettingWithCopyWarning = pd.core.common.SettingWithCopyWarning
  224. except AttributeError:
  225. # SettingWithCopyWarning was moved to pd.errors in Pandas 1.5.0.
  226. SettingWithCopyWarning = pd.errors.SettingWithCopyWarning
  227. from ray.data._internal.tensor_extensions.pandas import TensorDtype
  228. # Try to convert any tensor extension columns to ndarray columns.
  229. # TODO(Clark): Optimize this with propagated DataFrame metadata containing a list of
  230. # column names containing tensor columns, to make this an O(# of tensor columns)
  231. # check rather than the current O(# of columns) check.
  232. for col_name, col in df.items():
  233. if isinstance(col.dtype, TensorDtype):
  234. # Suppress Pandas warnings:
  235. # https://github.com/ray-project/ray/issues/29270
  236. # We actually want in-place operations so we surpress this warning.
  237. # https://stackoverflow.com/a/74193599
  238. with warnings.catch_warnings():
  239. warnings.simplefilter("ignore", category=FutureWarning)
  240. warnings.simplefilter("ignore", category=SettingWithCopyWarning)
  241. df[col_name] = list(col.to_numpy())
  242. return df