data_batch_conversion.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. import warnings
  2. from enum import Enum
  3. from typing import TYPE_CHECKING, Dict, List, Union
  4. import numpy as np
  5. from ray.air.constants import TENSOR_COLUMN_NAME
  6. from ray.air.data_batch_type import DataBatchType
  7. from ray.util.annotations import Deprecated, DeveloperAPI
  8. if TYPE_CHECKING:
  9. import pandas as pd
  10. # TODO: Consolidate data conversion edges for arrow bug workaround.
  11. try:
  12. import pyarrow
  13. except ImportError:
  14. pyarrow = None
  15. # Lazy import to avoid ray init failures without pandas installed and allow
  16. # dataset to import modules in this file.
  17. _pandas = None
  18. def _lazy_import_pandas():
  19. global _pandas
  20. if _pandas is None:
  21. import pandas
  22. _pandas = pandas
  23. return _pandas
@DeveloperAPI
class BatchFormat(str, Enum):
    """Target formats for converting a data batch.

    Used, for example, as the ``type`` argument of
    ``_convert_pandas_to_batch_type``. Because the enum also subclasses
    ``str``, each member compares equal to its plain string value
    (e.g. ``BatchFormat.PANDAS == "pandas"``).
    """

    PANDAS = "pandas"
    # TODO: Remove once Arrow is deprecated as user facing batch format
    ARROW = "arrow"
    NUMPY = "numpy"  # Either a single numpy array or a Dict of numpy arrays.
@DeveloperAPI
class BlockFormat(str, Enum):
    """Internal Dataset block format enum.

    Because the enum also subclasses ``str``, each member compares equal to its
    plain string value (e.g. ``BlockFormat.ARROW == "arrow"``).
    """

    PANDAS = "pandas"
    ARROW = "arrow"
    # NOTE(review): the meaning of SIMPLE is not visible in this file --
    # presumably blocks stored as plain Python lists; confirm against Ray Data.
    SIMPLE = "simple"
  36. def _convert_batch_type_to_pandas(
  37. data: DataBatchType,
  38. cast_tensor_columns: bool = False,
  39. ) -> "pd.DataFrame":
  40. """Convert the provided data to a Pandas DataFrame.
  41. Args:
  42. data: Data of type DataBatchType
  43. cast_tensor_columns: Whether tensor columns should be cast to NumPy ndarrays.
  44. Returns:
  45. A pandas Dataframe representation of the input data.
  46. """
  47. pd = _lazy_import_pandas()
  48. if isinstance(data, np.ndarray):
  49. data = pd.DataFrame({TENSOR_COLUMN_NAME: _ndarray_to_column(data)})
  50. elif isinstance(data, dict):
  51. tensor_dict = {}
  52. for col_name, col in data.items():
  53. if not isinstance(col, np.ndarray):
  54. raise ValueError(
  55. "All values in the provided dict must be of type "
  56. f"np.ndarray. Found type {type(col)} for key {col_name} "
  57. f"instead."
  58. )
  59. tensor_dict[col_name] = _ndarray_to_column(col)
  60. data = pd.DataFrame(tensor_dict)
  61. elif pyarrow is not None and isinstance(data, pyarrow.Table):
  62. data = data.to_pandas()
  63. elif not isinstance(data, pd.DataFrame):
  64. raise ValueError(
  65. f"Received data of type: {type(data)}, but expected it to be one "
  66. f"of {DataBatchType}"
  67. )
  68. if cast_tensor_columns:
  69. data = _cast_tensor_columns_to_ndarrays(data)
  70. return data
  71. def _convert_pandas_to_batch_type(
  72. data: "pd.DataFrame",
  73. type: BatchFormat,
  74. cast_tensor_columns: bool = False,
  75. ) -> DataBatchType:
  76. """Convert the provided Pandas dataframe to the provided ``type``.
  77. Args:
  78. data: A Pandas DataFrame
  79. type: The specific ``BatchFormat`` to convert to.
  80. cast_tensor_columns: Whether tensor columns should be cast to our tensor
  81. extension type.
  82. Returns:
  83. The input data represented with the provided type.
  84. """
  85. if cast_tensor_columns:
  86. data = _cast_ndarray_columns_to_tensor_extension(data)
  87. if type == BatchFormat.PANDAS:
  88. return data
  89. elif type == BatchFormat.NUMPY:
  90. if len(data.columns) == 1:
  91. # If just a single column, return as a single numpy array.
  92. return data.iloc[:, 0].to_numpy()
  93. else:
  94. # Else return as a dict of numpy arrays.
  95. output_dict = {}
  96. for column in data:
  97. output_dict[column] = data[column].to_numpy()
  98. return output_dict
  99. elif type == BatchFormat.ARROW:
  100. if not pyarrow:
  101. raise ValueError(
  102. "Attempted to convert data to Pyarrow Table but Pyarrow "
  103. "is not installed. Please do `pip install pyarrow` to "
  104. "install Pyarrow."
  105. )
  106. return pyarrow.Table.from_pandas(data)
  107. else:
  108. raise ValueError(
  109. f"Received type {type}, but expected it to be one of {DataBatchType}"
  110. )
  111. @Deprecated
  112. def convert_batch_type_to_pandas(
  113. data: DataBatchType,
  114. cast_tensor_columns: bool = False,
  115. ):
  116. """Convert the provided data to a Pandas DataFrame.
  117. This API is deprecated from Ray 2.4.
  118. Args:
  119. data: Data of type DataBatchType
  120. cast_tensor_columns: Whether tensor columns should be cast to NumPy ndarrays.
  121. Returns:
  122. A pandas Dataframe representation of the input data.
  123. """
  124. warnings.warn(
  125. "`convert_batch_type_to_pandas` is deprecated as a developer API "
  126. "starting from Ray 2.4. All batch format conversions should be "
  127. "done manually instead of relying on this API.",
  128. PendingDeprecationWarning,
  129. )
  130. return _convert_batch_type_to_pandas(
  131. data=data, cast_tensor_columns=cast_tensor_columns
  132. )
  133. @Deprecated
  134. def convert_pandas_to_batch_type(
  135. data: "pd.DataFrame",
  136. type: BatchFormat,
  137. cast_tensor_columns: bool = False,
  138. ):
  139. """Convert the provided Pandas dataframe to the provided ``type``.
  140. Args:
  141. data: A Pandas DataFrame
  142. type: The specific ``BatchFormat`` to convert to.
  143. cast_tensor_columns: Whether tensor columns should be cast to our tensor
  144. extension type.
  145. Returns:
  146. The input data represented with the provided type.
  147. """
  148. warnings.warn(
  149. "`convert_pandas_to_batch_type` is deprecated as a developer API "
  150. "starting from Ray 2.4. All batch format conversions should be "
  151. "done manually instead of relying on this API.",
  152. PendingDeprecationWarning,
  153. )
  154. return _convert_pandas_to_batch_type(
  155. data=data, type=type, cast_tensor_columns=cast_tensor_columns
  156. )
  157. def _convert_batch_type_to_numpy(
  158. data: DataBatchType,
  159. ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
  160. """Convert the provided data to a NumPy ndarray or dict of ndarrays.
  161. Args:
  162. data: Data of type DataBatchType
  163. Returns:
  164. A numpy representation of the input data.
  165. """
  166. pd = _lazy_import_pandas()
  167. if isinstance(data, np.ndarray):
  168. return data
  169. elif isinstance(data, dict):
  170. for col_name, col in data.items():
  171. if not isinstance(col, np.ndarray):
  172. raise ValueError(
  173. "All values in the provided dict must be of type "
  174. f"np.ndarray. Found type {type(col)} for key {col_name} "
  175. f"instead."
  176. )
  177. return data
  178. elif pyarrow is not None and isinstance(data, pyarrow.Table):
  179. from ray.data._internal.arrow_ops import transform_pyarrow
  180. from ray.data._internal.tensor_extensions.arrow import (
  181. get_arrow_extension_fixed_shape_tensor_types,
  182. )
  183. column_values_ndarrays = []
  184. for col in data.columns:
  185. # Combine columnar values arrays to make these contiguous
  186. # (making them compatible with numpy format)
  187. combined_array = transform_pyarrow.combine_chunked_array(col)
  188. column_values_ndarrays.append(
  189. transform_pyarrow.to_numpy(combined_array, zero_copy_only=False)
  190. )
  191. arrow_fixed_shape_tensor_types = get_arrow_extension_fixed_shape_tensor_types()
  192. # NOTE: This branch is here for backwards-compatibility
  193. if data.column_names == [TENSOR_COLUMN_NAME] and (
  194. isinstance(data.schema.types[0], arrow_fixed_shape_tensor_types)
  195. ):
  196. return column_values_ndarrays[0]
  197. return dict(zip(data.column_names, column_values_ndarrays))
  198. elif isinstance(data, pd.DataFrame):
  199. return _convert_pandas_to_batch_type(data, BatchFormat.NUMPY)
  200. else:
  201. raise ValueError(
  202. f"Received data of type: {type(data)}, but expected it to be one "
  203. f"of {DataBatchType}"
  204. )
  205. def _ndarray_to_column(arr: np.ndarray) -> Union["pd.Series", List[np.ndarray]]:
  206. """Convert a NumPy ndarray into an appropriate column format for insertion into a
  207. pandas DataFrame.
  208. If conversion to a pandas Series fails (e.g. if the ndarray is multi-dimensional),
  209. fall back to a list of NumPy ndarrays.
  210. """
  211. pd = _lazy_import_pandas()
  212. try:
  213. # Try to convert to Series, falling back to a list conversion if this fails
  214. # (e.g. if the ndarray is multi-dimensional).
  215. return pd.Series(arr)
  216. except ValueError:
  217. return list(arr)
  218. def _unwrap_ndarray_object_type_if_needed(arr: np.ndarray) -> np.ndarray:
  219. """Unwrap an object-dtyped NumPy ndarray containing ndarray pointers into a single
  220. contiguous ndarray, if needed/possible.
  221. """
  222. if arr.dtype.type is np.object_:
  223. try:
  224. # Try to convert the NumPy ndarray to a non-object dtype.
  225. arr = np.array([np.asarray(v) for v in arr])
  226. except Exception:
  227. # This may fail if the subndarrays are of heterogeneous shape
  228. pass
  229. return arr
  230. def _cast_ndarray_columns_to_tensor_extension(df: "pd.DataFrame") -> "pd.DataFrame":
  231. """
  232. Cast all NumPy ndarray columns in df to our tensor extension type, TensorArray.
  233. """
  234. pd = _lazy_import_pandas()
  235. try:
  236. SettingWithCopyWarning = pd.core.common.SettingWithCopyWarning
  237. except AttributeError:
  238. # SettingWithCopyWarning was moved to pd.errors in Pandas 1.5.0.
  239. SettingWithCopyWarning = pd.errors.SettingWithCopyWarning
  240. from ray.data._internal.tensor_extensions.pandas import (
  241. TensorArray,
  242. column_needs_tensor_extension,
  243. )
  244. # Try to convert any ndarray columns to TensorArray columns.
  245. # TODO(Clark): Once Pandas supports registering extension types for type
  246. # inference on construction, implement as much for NumPy ndarrays and remove
  247. # this. See https://github.com/pandas-dev/pandas/issues/41848
  248. # TODO(Clark): Optimize this with propagated DataFrame metadata containing a list of
  249. # column names containing tensor columns, to make this an O(# of tensor columns)
  250. # check rather than the current O(# of columns) check.
  251. for col_name, col in df.items():
  252. if column_needs_tensor_extension(col):
  253. try:
  254. # Suppress Pandas warnings:
  255. # https://github.com/ray-project/ray/issues/29270
  256. # We actually want in-place operations so we surpress this warning.
  257. # https://stackoverflow.com/a/74193599
  258. with warnings.catch_warnings():
  259. warnings.simplefilter("ignore", category=FutureWarning)
  260. warnings.simplefilter("ignore", category=SettingWithCopyWarning)
  261. df[col_name] = TensorArray(col)
  262. except Exception as e:
  263. raise ValueError(
  264. f"Tried to cast column {col_name} to the TensorArray tensor "
  265. "extension type but the conversion failed. To disable "
  266. "automatic casting to this tensor extension, set "
  267. "ctx = DataContext.get_current(); "
  268. "ctx.enable_tensor_extension_casting = False."
  269. ) from e
  270. return df
  271. def _cast_tensor_columns_to_ndarrays(df: "pd.DataFrame") -> "pd.DataFrame":
  272. """Cast all tensor extension columns in df to NumPy ndarrays."""
  273. pd = _lazy_import_pandas()
  274. try:
  275. SettingWithCopyWarning = pd.core.common.SettingWithCopyWarning
  276. except AttributeError:
  277. # SettingWithCopyWarning was moved to pd.errors in Pandas 1.5.0.
  278. SettingWithCopyWarning = pd.errors.SettingWithCopyWarning
  279. from ray.data._internal.tensor_extensions.pandas import TensorDtype
  280. # Try to convert any tensor extension columns to ndarray columns.
  281. # TODO(Clark): Optimize this with propagated DataFrame metadata containing a list of
  282. # column names containing tensor columns, to make this an O(# of tensor columns)
  283. # check rather than the current O(# of columns) check.
  284. for col_name, col in df.items():
  285. if isinstance(col.dtype, TensorDtype):
  286. # Suppress Pandas warnings:
  287. # https://github.com/ray-project/ray/issues/29270
  288. # We actually want in-place operations so we surpress this warning.
  289. # https://stackoverflow.com/a/74193599
  290. with warnings.catch_warnings():
  291. warnings.simplefilter("ignore", category=FutureWarning)
  292. warnings.simplefilter("ignore", category=SettingWithCopyWarning)
  293. df[col_name] = list(col.to_numpy())
  294. return df