torch_utils.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. import warnings
  2. from concurrent.futures import ThreadPoolExecutor
  3. from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
  4. import numpy as np
  5. import pandas as pd
  6. import pyarrow
  7. import torch
  8. from ray._private.ray_constants import env_bool
  9. from ray.data.collate_fn import (
  10. TensorBatchReturnType,
  11. TensorBatchType,
  12. _is_nested_tensor_sequence,
  13. _is_tensor,
  14. _is_tensor_mapping,
  15. _is_tensor_sequence,
  16. _is_tensor_sequence_mapping,
  17. )
  18. from ray.data.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
  # Module-wide default for the ``non_blocking`` argument of the tensor
  # transfer helpers in this file. The value is read through ``env_bool`` from
  # the ``RAY_AIR_DEFAULT_TENSOR_NON_BLOCKING_TRANSFER`` environment variable,
  # with ``True`` passed as the fallback.
  DEFAULT_TENSOR_NON_BLOCKING_TRANSFER = env_bool(
      "RAY_AIR_DEFAULT_TENSOR_NON_BLOCKING_TRANSFER", True
  )
  def convert_pandas_to_torch_tensor(
      data_batch: pd.DataFrame,
      columns: Optional[Union[List[str], List[List[str]]]] = None,
      column_dtypes: Optional[Union[torch.dtype, List[torch.dtype]]] = None,
      unsqueeze: bool = True,
  ) -> Union[torch.Tensor, List[torch.Tensor]]:
      """Convert a pandas DataFrame to a torch Tensor or a list of torch Tensors.

      The format of the return type matches the format of ``columns``: a flat
      list of column names (or ``None``) yields a single tensor, while a list
      of lists yields a list of tensors.

      Args:
          data_batch: The pandas dataframe to convert.
          columns: The names of the columns to include. If this is a
              ``List[List[str]]``, one tensor is returned per inner list,
              which is useful for multi-input models. If None or empty, all
              columns of ``data_batch`` are used.
          column_dtypes: The torch dtype to use for the tensor(s). For
              multi-input ``columns`` this may also be a list/tuple holding
              one dtype per column group. If None, the dtype is inferred.
          unsqueeze: If True, each column tensor is unsqueezed along dim 1
              (reshaped to (N, 1)) before the columns are concatenated along
              dim 1. Otherwise they are left as is, that is (N, ).
              Defaults to True.

      Returns:
          Either a single torch tensor built from the selected columns, or a
          list of such tensors (one per column group in ``columns``).

      Raises:
          TypeError: If ``columns`` is not multi-input and ``column_dtypes``
              is truthy but not a single ``torch.dtype``.
          ValueError: If a column cannot be converted to a tensor; the
              underlying error is chained as the cause.
      """
      multi_input = bool(columns) and isinstance(columns[0], (list, tuple))

      if (
          not multi_input
          and column_dtypes
          and not isinstance(column_dtypes, torch.dtype)
      ):
          raise TypeError(
              "If `columns` is a list of strings, "
              "`column_dtypes` must be None or a single `torch.dtype`. "
              f"Got {type(column_dtypes)} instead."
          )

      columns = columns if columns else []

      def make_nested(tensors):
          """Build a nested tensor from a list of (possibly ragged) tensors."""
          # Prefer the ``torch.nested`` namespace; fall back to the legacy
          # top-level name for torch builds that only expose that one.
          nested_module = getattr(torch, "nested", None)
          if nested_module is not None and hasattr(nested_module, "nested_tensor"):
              return nested_module.nested_tensor(tensors)
          return torch.nested_tensor(tensors)

      def tensorize(vals, dtype):
          """Recursively convert ``vals`` to a tensor.

          The recursion allows object-dtype columns (e.g. pyarrow List dtypes)
          to be converted to multi-dimensional tensors.
          """
          if isinstance(vals, pd.api.extensions.ExtensionArray):
              # torch.as_tensor() does not yet support the __array__ protocol,
              # so extension arrays are converted to ndarrays manually first.
              # See https://github.com/pytorch/pytorch/issues/51156.
              vals = vals.to_numpy()

          if vals.dtype.type is not np.object_:
              return torch.as_tensor(vals, dtype=dtype)

          # Column has an object dtype which Torch can't handle, so tensorize
          # each element and then stack the resulting tensors.
          tensors = [tensorize(x, dtype) for x in vals]
          try:
              return torch.stack(tensors)
          except RuntimeError:
              # NOTE: RuntimeError is raised when trying to stack ragged
              # tensors. Try to coerce to a nested tensor instead; if that
              # fails too, the exception propagates to the caller.
              return make_nested(tensors)

      def get_tensor_for_columns(columns, dtype):
          """Build one tensor out of the given column group."""
          batch = data_batch[columns] if columns else data_batch

          feature_tensors = []
          for col in batch.columns:
              col_vals = batch[col].values
              try:
                  t = tensorize(col_vals, dtype=dtype)
              except Exception as e:
                  raise ValueError(
                      f"Failed to convert column {col} to a Torch Tensor of dtype "
                      f"{dtype}. See above exception chain for the exact failure."
                  ) from e
              if unsqueeze:
                  t = t.unsqueeze(1)
              feature_tensors.append(t)

          if len(feature_tensors) > 1:
              return torch.cat(feature_tensors, dim=1)
          return feature_tensors[0]

      if not multi_input:
          return get_tensor_for_columns(columns=columns, dtype=column_dtypes)

      # Broadcast a single dtype (or None) across every column group.
      if not isinstance(column_dtypes, (list, tuple)):
          column_dtypes = [column_dtypes] * len(columns)

      return [
          get_tensor_for_columns(columns=subcolumns, dtype=dtype)
          for subcolumns, dtype in zip(columns, column_dtypes)
      ]
  def convert_ndarray_to_torch_tensor(
      ndarray: np.ndarray,
      dtype: Optional[torch.dtype] = None,
      device: Optional[Union[str, "torch.device"]] = None,
      pin_memory: bool = False,
  ) -> torch.Tensor:
      """Turn a single NumPy ndarray into a Torch Tensor.

      Args:
          ndarray: The NumPy ndarray to convert.
          dtype: Torch dtype for the created tensor; when None, the dtype is
              inferred from the ndarray data.
          device: Device on which the tensor should be placed; when None, the
              tensor is constructed on the CPU.
          pin_memory: Whether to pin the memory of the created tensor.

      Returns:
          A Torch Tensor.

      Raises:
          RuntimeError: If the array still has object dtype after unwrapping.
      """
      unwrapped = _unwrap_ndarray_object_type_if_needed(ndarray)

      # Torch has no equivalent of an object-dtype array, so bail out early.
      if unwrapped.dtype.type is np.object_:
          raise RuntimeError(
              "Numpy array of object dtype cannot be converted to a Torch Tensor. This "
              "may because the numpy array is a ragged tensor--it contains items of "
              "different sizes. If using `iter_torch_batches()` API, you can pass in a "
              "`collate_fn` argument to specify custom logic to convert the Numpy array "
              "batch to a Torch tensor batch."
          )

      # The array may be read-only (it can come from the Ray object store), in
      # which case the conversion emits a verbose warning. The tensors are never
      # written to, and copying would add memory overhead, so the warning is
      # silenced instead.
      # Original warning: https://github.com/pytorch/pytorch/blob/v1.13.0/
      # torch/csrc/utils/tensor_numpy.cpp#L198-L206
      with warnings.catch_warnings():
          warnings.simplefilter("ignore")
          tensor = torch.as_tensor(unwrapped, dtype=dtype, device=device)

      if not pin_memory:
          return tensor

      assert tensor.device.type == "cpu", (
          "Pin memory is only supported for CPU tensors. "
          f"Got device: {tensor.device} and pin_memory: {pin_memory}."
      )
      return tensor.pin_memory()
  def convert_ndarray_batch_to_torch_tensor_batch(
      ndarrays: Union[np.ndarray, Dict[str, np.ndarray]],
      dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
      device: Optional[Union[str, "torch.device"]] = None,
      pin_memory: bool = False,
  ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
      """Turn a NumPy ndarray batch into a Torch Tensor batch.

      Args:
          ndarrays: Either one ndarray or a dict of ndarrays to convert.
          dtypes: Either one Torch dtype or a dict of dtypes keyed like
              ``ndarrays``; when None, dtypes are inferred from the data.
          device: Device on which the tensor(s) should be placed; when None,
              they are constructed on the CPU.
          pin_memory: Whether to pin the memory of the created tensors.

      Returns:
          A single Torch Tensor for ndarray input, otherwise a dict of tensors.

      Raises:
          ValueError: If a single ndarray is given together with a dtype dict
              that does not contain exactly one entry.
      """
      if not isinstance(ndarrays, np.ndarray):
          # Multi-tensor case: convert column by column, picking the matching
          # dtype when a per-column dict was supplied.
          return {
              name: convert_ndarray_to_torch_tensor(
                  array,
                  dtype=dtypes[name] if isinstance(dtypes, dict) else dtypes,
                  device=device,
                  pin_memory=pin_memory,
              )
              for name, array in ndarrays.items()
          }

      # Single-tensor case: a dtype dict is only accepted with one entry.
      single_dtype = dtypes
      if isinstance(dtypes, dict):
          if len(dtypes) != 1:
              raise ValueError(
                  "When constructing a single-tensor batch, only a single dtype "
                  f"should be given, instead got: {dtypes}"
              )
          (single_dtype,) = dtypes.values()

      return convert_ndarray_to_torch_tensor(
          ndarrays,
          dtype=single_dtype,
          device=device,
          pin_memory=pin_memory,
      )
  def convert_ndarray_list_to_torch_tensor_list(
      ndarrays: Dict[str, List[np.ndarray]],
      dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
      device: Optional[Union[str, "torch.device"]] = None,
      pin_memory: bool = False,
  ) -> Dict[str, List[torch.Tensor]]:
      """Turn a dict of ndarray lists into a dict of Torch Tensor lists.

      Args:
          ndarrays: Dict mapping column names to lists of ndarrays.
          dtypes: Either one Torch dtype or a dict of dtypes keyed by column
              name; when None, dtypes are inferred from the data.
          device: Device on which the tensor(s) should be placed; when None,
              they are constructed on the CPU.
          pin_memory: Whether to pin the memory of the created tensors.

      Returns:
          A dict mapping each column name to its list of converted tensors.
      """
      converted: Dict[str, List[torch.Tensor]] = {}
      for name, chunks in ndarrays.items():
          column_tensors = []
          for chunk in chunks:
              # The dtype lookup stays inside the per-chunk loop so that a
              # column without chunks never touches the dtype dict.
              chunk_dtype = dtypes[name] if isinstance(dtypes, dict) else dtypes
              column_tensors.append(
                  convert_ndarray_batch_to_torch_tensor_batch(
                      chunk,
                      dtypes=chunk_dtype,
                      device=device,
                      pin_memory=pin_memory,
                  )
              )
          converted[name] = column_tensors
      return converted
  def arrow_batch_to_tensors(
      batch: pyarrow.Table,
      dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
      combine_chunks: bool = False,
      pin_memory: bool = False,
      threadpool: Optional[ThreadPoolExecutor] = None,
  ) -> Union[Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]]]:
      """Convert PyArrow batch to PyTorch tensors.

      Args:
          batch: PyArrow batch to convert
          dtypes: A (dict of) Torch dtype(s) for the created tensors; if None,
              the dtype will be inferred from the NumPy ndarray data.
          combine_chunks: If True, combine chunks in Arrow batch before
              converting to tensors.
          pin_memory: Whether to pin the memory of the created tensors.
          threadpool: Optional ThreadPoolExecutor for parallel processing. If
              provided (and there is more than one column/array to convert),
              columns/arrays are processed in parallel. If None, processing is
              sequential.

      Returns:
          When combine_chunks=True: A dictionary of column name to single tensor.
          When combine_chunks=False: A dictionary of column name to list of tensors.
      """
      from ray.data._internal.arrow_block import ArrowBlockAccessor
      from ray.data._internal.arrow_ops import transform_pyarrow

      def to_tensor(col_name: str, array: np.ndarray) -> torch.Tensor:
          """Convert one array, using the per-column dtype when one is given."""
          return convert_ndarray_batch_to_torch_tensor_batch(
              array,
              dtypes=dtypes[col_name] if isinstance(dtypes, dict) else dtypes,
              pin_memory=pin_memory,
          )

      if combine_chunks:
          numpy_batch = ArrowBlockAccessor(batch).to_batch_format("numpy")

          if len(numpy_batch) > 1 and threadpool is not None:
              # Process columns in parallel; Executor.map yields results in
              # the order of its inputs, so names and tensors line up.
              col_names = list(numpy_batch)
              tensors = threadpool.map(to_tensor, col_names, numpy_batch.values())
              return dict(zip(col_names, tensors))

          # Sequential processing for single column or no threadpool.
          return {
              col_name: to_tensor(col_name, col_array)
              for col_name, col_array in numpy_batch.items()
          }

      numpy_list = transform_pyarrow.table_to_numpy_dict_chunked(batch)

      # Count total number of arrays across all columns.
      total_arrays = sum(len(arrays) for arrays in numpy_list.values())

      if total_arrays <= 1 or threadpool is None:
          # Sequential processing.
          return convert_ndarray_list_to_torch_tensor_list(
              numpy_list,
              dtypes=dtypes,
              pin_memory=pin_memory,
          )

      # Flatten to (column, position, array) triples so every chunk becomes
      # its own unit of work for the threadpool.
      array_items = [
          (col_name, idx, array)
          for col_name, arrays in numpy_list.items()
          for idx, array in enumerate(arrays)
      ]
      tensors = list(
          threadpool.map(lambda item: to_tensor(item[0], item[2]), array_items)
      )

      # Start from every column in numpy_list (including empty ones) with a
      # pre-sized slot list, then drop each tensor into its original position.
      result: Dict[str, List[torch.Tensor]] = {
          col_name: [None] * len(arrays) for col_name, arrays in numpy_list.items()
      }
      for (col_name, idx, _), tensor in zip(array_items, tensors):
          result[col_name][idx] = tensor
      return result
  @torch.no_grad()
  def concat_tensors_to_device(
      tensor_sequence: Sequence[torch.Tensor],
      device: Optional[Union[str, "torch.device"]] = None,
      non_blocking: bool = DEFAULT_TENSOR_NON_BLOCKING_TRANSFER,
  ) -> torch.Tensor:
      """Concatenate a sequence of tensors along dim 0 into one tensor on ``device``.

      The output buffer is allocated once on the target device and each input
      is copied into its row range.

      Args:
          tensor_sequence: Non-empty sequence of tensors sharing the same dtype
              and the same trailing shape (``shape[1:]``).
          device: The device to move tensors to. If None, the result is placed
              on the device of the first tensor.
          non_blocking: If True, perform device transfer without forcing a
              synchronization.

      Returns:
          A single tensor holding all rows of the inputs, in order. When the
          sequence has exactly one tensor that already lives on the requested
          device (or ``device`` is None), that tensor itself is returned.

      Raises:
          AssertionError: If the sequence is empty, contains non-tensors, or
              the tensors disagree on dtype or trailing shape.
      """
      assert (
          tensor_sequence
      ), f"Cannot stack empty sequence of tensors. Received: {tensor_sequence}"

      assert all(
          isinstance(t, torch.Tensor) for t in tensor_sequence
      ), "All items must be torch.Tensor. Found invalid types: " + str(
          [type(t) for t in tensor_sequence if not isinstance(t, torch.Tensor)]
      )

      first = tensor_sequence[0]

      # If there is only one tensor and its device already matches, return it
      # directly.
      if len(tensor_sequence) == 1 and (
          device is None or first.device == torch.device(device)
      ):
          return first

      assert all(t.dtype == first.dtype for t in tensor_sequence), (
          "All tensors must have the same dtype. "
          f"Expected: {first.dtype}, got: {[t.dtype for t in tensor_sequence]}"
      )

      shape_tail = first.shape[1:]
      assert all(t.shape[1:] == shape_tail for t in tensor_sequence), (
          "All tensors must have the same shape[1:]. "
          f"Expected: {shape_tail}, got: {[t.shape[1:] for t in tensor_sequence]}"
      )

      # With no explicit device, stay on the inputs' device. Passing
      # ``device=None`` to ``torch.empty`` would allocate on torch's default
      # device instead, silently moving the data.
      target_device = first.device if device is None else device
      total_rows = sum(t.shape[0] for t in tensor_sequence)

      # Allocate the output once, then copy each input into its row range.
      result = torch.empty(
          (total_rows, *shape_tail), dtype=first.dtype, device=target_device
      )
      row_start = 0
      for t in tensor_sequence:
          row_end = row_start + t.shape[0]
          result[row_start:row_end].copy_(t, non_blocking=non_blocking)
          row_start = row_end

      return result
  def _get_type_str(batch: Any) -> str:
      """Describe the (possibly nested) type of ``batch`` as a string.

      Lists/tuples render as ``name[<member types>]`` and dicts as
      ``name[str, <value types>]``, where the distinct member type strings are
      sorted and joined with ``" | "``. Anything else renders as its bare type
      name.

      >>> import torch
      >>> _get_type_str([1, 2, "???"])
      'list[int | str]'
      >>> _get_type_str({"a": [1, 2, 3], "b": 4})
      'dict[str, int | list[int]]'
      >>> _get_type_str({"a": torch.tensor(1), "b": [torch.tensor(2)]})
      'dict[str, Tensor | list[Tensor]]'
      >>> _get_type_str({"a": torch.tensor(1), "b": {"c": torch.tensor(2)}})
      'dict[str, Tensor | dict[str, Tensor]]'
      """
      type_name = type(batch).__name__

      if isinstance(batch, (list, tuple)):
          members, key_prefix = batch, ""
      elif isinstance(batch, dict):
          members, key_prefix = batch.values(), "str, "
      else:
          # Leaf value: nothing to recurse into.
          return type_name

      # De-duplicate the member descriptions, then order them deterministically.
      member_types = sorted(set(map(_get_type_str, members)))
      return f"{type_name}[{key_prefix}{' | '.join(member_types)}]"
  @torch.no_grad()
  def move_tensors_to_device(
      batch: TensorBatchType,
      device: Optional[Union[str, "torch.device"]] = None,
      non_blocking: bool = DEFAULT_TENSOR_NON_BLOCKING_TRANSFER,
  ) -> TensorBatchReturnType:
      """Transfer a tensor batch to ``device``.

      Nested lists/tuples of tensors are concatenated along the first (batch)
      dimension as part of the transfer. For example, for the input
      ((feature_0_chunk_0,), (feature_1_chunk_0, feature_1_chunk_1))
      the output will be (feature_0_chunk_0, feature_1_chunk_0+1)
      where each feature is concatenated along the batch dimension.

      Args:
          batch: A tensor or collection of tensors to move to device. Can be:
              - A single tensor
              - A sequence of tensors
              - A sequence of sequences of tensors. The inner sequence of
                tensors is combined during the transfer.
              - A mapping (e.g., dict) of keys to tensors or sequences of
                tensors. The sequence of tensors is combined during the
                transfer.
          device: The device to move tensors to. If None, ``batch`` is returned
              unchanged.
          non_blocking: If True, perform device transfer without forcing a
              synchronization.

      Returns:
          The input tensors moved to the specified device.

      Raises:
          ValueError: If ``batch`` matches none of the supported layouts.
      """
      if device is None:
          return batch

      def transfer(tensor):
          # Plain per-tensor move shared by the non-nested layouts.
          return tensor.to(device, non_blocking=non_blocking)

      def merge(chunks):
          # Nested layouts: concatenate the chunks while moving them.
          return concat_tensors_to_device(chunks, device, non_blocking)

      # The layout predicates are tried in a fixed order; the first match wins.
      if _is_tensor(batch):
          return transfer(batch)

      if _is_tensor_sequence(batch):
          return type(batch)([transfer(item) for item in batch])

      if _is_nested_tensor_sequence(batch):
          return type(batch)([merge(chunks) for chunks in batch])

      if _is_tensor_mapping(batch):
          return {key: transfer(value) for key, value in batch.items()}

      if _is_tensor_sequence_mapping(batch):
          return {key: merge(chunks) for key, chunks in batch.items()}

      raise ValueError(
          f"Invalid input type: {_get_type_str(batch)}.\n"
          "Expected one of the following: "
          "torch.Tensor, "
          "List/Tuple[torch.Tensor], "
          "Dict[str, torch.Tensor], "
          "Mapping[str, List/Tuple[torch.Tensor]]"
      )