collate_fn.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. import abc
  2. from concurrent.futures import ThreadPoolExecutor
  3. from typing import (
  4. TYPE_CHECKING,
  5. Any,
  6. Dict,
  7. Generic,
  8. List,
  9. Mapping,
  10. Optional,
  11. Tuple,
  12. TypeVar,
  13. Union,
  14. )
  15. import numpy as np
  16. from ray._private.ray_constants import env_integer
  17. from ray.util.annotations import DeveloperAPI
  18. if TYPE_CHECKING:
  19. import pandas
  20. import pyarrow
  21. import torch
  22. from ray.data.block import DataBatch
  23. from ray.data.dataset import CollatedData, TorchDeviceType
  24. DataBatchType = TypeVar("DataBatchType", bound="DataBatch")
  25. TensorSequenceType = Union[
  26. List["torch.Tensor"],
  27. Tuple["torch.Tensor", ...],
  28. ]
  29. TensorBatchType = Union[
  30. "torch.Tensor",
  31. TensorSequenceType,
  32. # For nested sequences of tensors, the inner sequence of tensors is combined during
  33. # GPU transfer in `move_tensors_to_device`.
  34. List[TensorSequenceType],
  35. Tuple[TensorSequenceType, ...],
  36. Mapping[str, "torch.Tensor"],
  37. # For mapping (e.g., dict) of keys to sequences of tensors, the sequence of tensors
  38. # is combined during GPU transfer in `move_tensors_to_device`.
  39. Mapping[str, TensorSequenceType],
  40. ]
  41. def _is_tensor(batch: Any) -> bool:
  42. """Check if a batch is a single torch.Tensor."""
  43. import torch
  44. return isinstance(batch, torch.Tensor)
  45. def _is_tensor_sequence(batch: Any) -> bool:
  46. """Check if a batch is a sequence of torch.Tensors.
  47. >>> import torch
  48. >>> _is_tensor_sequence(torch.ones(1))
  49. False
  50. >>> _is_tensor_sequence([torch.ones(1), torch.ones(1)])
  51. True
  52. >>> _is_tensor_sequence((torch.ones(1), torch.ones(1)))
  53. True
  54. >>> _is_tensor_sequence([torch.ones(1), 1])
  55. False
  56. """
  57. return isinstance(batch, (list, tuple)) and all(_is_tensor(t) for t in batch)
  58. def _is_nested_tensor_sequence(batch: Any) -> bool:
  59. """Check if a batch is a sequence of sequences of torch.Tensors.
  60. Stops at one level of nesting.
  61. >>> import torch
  62. >>> _is_nested_tensor_sequence([torch.ones(1), torch.ones(1)])
  63. False
  64. >>> _is_nested_tensor_sequence(
  65. ... ([torch.ones(1), torch.ones(1)], [torch.ones(1)])
  66. ... )
  67. True
  68. """
  69. return isinstance(batch, (list, tuple)) and all(
  70. _is_tensor_sequence(t) for t in batch
  71. )
  72. def _is_tensor_mapping(batch: Any) -> bool:
  73. """Check if a batch is a mapping of keys to torch.Tensors.
  74. >>> import torch
  75. >>> _is_tensor_mapping({"a": torch.ones(1), "b": torch.ones(1)})
  76. True
  77. >>> _is_tensor_mapping({"a": torch.ones(1), "b": [torch.ones(1), torch.ones(1)]})
  78. False
  79. """
  80. return isinstance(batch, Mapping) and all(_is_tensor(v) for v in batch.values())
  81. def _is_tensor_sequence_mapping(batch: Any) -> bool:
  82. """Check if a batch is a mapping of keys to sequences of torch.Tensors.
  83. >>> import torch
  84. >>> _is_tensor_sequence_mapping({"a": torch.ones(1), "b": torch.ones(1)})
  85. False
  86. >>> _is_tensor_sequence_mapping(
  87. ... {"a": (torch.ones(1), torch.ones(1)), "b": [torch.ones(1), torch.ones(1)]}
  88. ... )
  89. True
  90. """
  91. return isinstance(batch, Mapping) and all(
  92. _is_tensor_sequence(v) for v in batch.values()
  93. )
  94. @DeveloperAPI
  95. def is_tensor_batch_type(batch: Any) -> bool:
  96. """Check if a batch matches any of the TensorBatchType variants.
  97. This function checks if the input batch is one of the following types:
  98. 1. A single torch.Tensor
  99. 2. A sequence of torch.Tensors
  100. 3. A sequence of sequences of torch.Tensors
  101. 4. A mapping (e.g., dict) of keys to torch.Tensors
  102. 5. A mapping (e.g., dict) of keys to sequences of torch.Tensors
  103. Args:
  104. batch: The input batch to check. Can be any type.
  105. Returns:
  106. bool: True if the batch matches any TensorBatchType variant, False otherwise.
  107. """
  108. return (
  109. _is_tensor(batch)
  110. or _is_tensor_sequence(batch)
  111. or _is_nested_tensor_sequence(batch)
  112. or _is_tensor_mapping(batch)
  113. or _is_tensor_sequence_mapping(batch)
  114. )
  115. TensorBatchReturnType = Union[
  116. "torch.Tensor",
  117. Tuple["torch.Tensor", ...],
  118. Dict[str, "torch.Tensor"],
  119. ]
  120. @DeveloperAPI
  121. class CollateFn(Generic[DataBatchType]):
  122. """Abstract interface for collate_fn for `iter_torch_batches`. See doc-string of
  123. `collate_fn` in `iter_torch_batches` API for more details.
  124. """
  125. @abc.abstractmethod
  126. def __call__(self, batch: DataBatchType) -> "CollatedData":
  127. """Convert a batch of data to collated format.
  128. Args:
  129. batch: The input batch to collate.
  130. Returns:
  131. The collated data in the format expected by the model.
  132. """
  133. ...
  134. @DeveloperAPI
  135. class ArrowBatchCollateFn(CollateFn["pyarrow.Table"]):
  136. """Collate function that takes pyarrow.Table as the input batch type.
  137. Arrow tables with chunked arrays can be efficiently transferred to GPUs without
  138. combining the chunks with the `arrow_batch_to_tensors` utility function.
  139. See `DefaultCollateFn` for example.
  140. """
  141. def __call__(self, batch: "pyarrow.Table") -> "CollatedData":
  142. """Convert a batch of pyarrow.Table to collated format.
  143. Args:
  144. batch: The input pyarrow.Table batch to collate.
  145. Returns:
  146. The collated data in the format expected by the model.
  147. """
  148. ...
  149. @DeveloperAPI
  150. class NumpyBatchCollateFn(CollateFn[Dict[str, np.ndarray]]):
  151. """Collate function that takes a dictionary of numpy arrays as the input batch type."""
  152. def __call__(self, batch: Dict[str, np.ndarray]) -> "CollatedData":
  153. """Convert a batch of numpy arrays to collated format.
  154. Args:
  155. batch: The input dictionary of numpy arrays batch to collate.
  156. Returns:
  157. The collated data in the format expected by the model.
  158. """
  159. ...
  160. @DeveloperAPI
  161. class PandasBatchCollateFn(CollateFn["pandas.DataFrame"]):
  162. """Collate function that takes a pandas.DataFrame as the input batch type."""
  163. def __call__(self, batch: "pandas.DataFrame") -> "CollatedData":
  164. """Convert a batch of pandas.DataFrame to collated format.
  165. Args:
  166. batch: The input pandas.DataFrame batch to collate.
  167. Returns:
  168. The collated data in the format expected by the model.
  169. """
  170. ...
  171. @DeveloperAPI
  172. class DefaultCollateFn(ArrowBatchCollateFn):
  173. """Default collate function for converting Arrow batches to PyTorch tensors."""
  174. _DEFAULT_NUM_WORKERS = env_integer(
  175. "RAY_DATA_DEFAULT_COLLATE_FN_THREADPOOL_MAX_WORKERS",
  176. 4,
  177. )
  178. def __init__(
  179. self,
  180. dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
  181. device: Optional["TorchDeviceType"] = None,
  182. pin_memory: bool = False,
  183. num_workers: int = _DEFAULT_NUM_WORKERS,
  184. ):
  185. """Initialize the collate function.
  186. Args:
  187. dtypes: The torch dtype(s) for the created tensor(s); if None, the dtype
  188. will be inferred from the tensor data.
  189. device: The device on which the tensor should be placed. Can be a string
  190. (e.g. "cpu", "cuda:0") or a torch.device object.
  191. pin_memory: Whether to pin the memory of the created tensors.
  192. num_workers: Number of worker threads for parallel tensor conversion.
  193. Defaults to `RAY_DATA_DEFAULT_COLLATE_FN_THREADPOOL_MAX_WORKERS`.
  194. """
  195. import torch
  196. super().__init__()
  197. self.dtypes = dtypes
  198. if isinstance(device, (str, int)):
  199. self.device = torch.device(device)
  200. else:
  201. self.device = device
  202. self.pin_memory = pin_memory
  203. self.num_workers = num_workers
  204. self._threadpool: Optional[ThreadPoolExecutor] = None
  205. def __del__(self):
  206. """Clean up threadpool on destruction."""
  207. if getattr(self, "_threadpool", None):
  208. self._threadpool.shutdown(wait=False)
  209. def __call__(
  210. self, batch: "pyarrow.Table"
  211. ) -> Union[Dict[str, "torch.Tensor"], Dict[str, List["torch.Tensor"]]]:
  212. """Convert an Arrow batch to PyTorch tensors.
  213. Args:
  214. batch: PyArrow Table to convert
  215. Returns:
  216. Dictionary mapping column names to lists of tensors
  217. """
  218. from ray.data.util.torch_utils import (
  219. arrow_batch_to_tensors,
  220. )
  221. if self.num_workers > 0 and self._threadpool is None:
  222. self._threadpool = ThreadPoolExecutor(max_workers=self.num_workers)
  223. # For GPU transfer, we can skip the combining chunked arrays. This is because
  224. # we can convert the chunked arrays to corresponding numpy format and then to
  225. # Tensors and transfer the corresponding list of Tensors to GPU directly.
  226. # However, for CPU transfer, we need to combine the chunked arrays first
  227. # before converting to numpy format and then to Tensors.
  228. combine_chunks = self.device is not None and self.device.type == "cpu"
  229. return arrow_batch_to_tensors(
  230. batch,
  231. dtypes=self.dtypes,
  232. combine_chunks=combine_chunks,
  233. pin_memory=self.pin_memory,
  234. threadpool=self._threadpool,
  235. )