| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298 |
- import abc
- from concurrent.futures import ThreadPoolExecutor
- from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- Generic,
- List,
- Mapping,
- Optional,
- Tuple,
- TypeVar,
- Union,
- )
- import numpy as np
- from ray._private.ray_constants import env_integer
- from ray.util.annotations import DeveloperAPI
- if TYPE_CHECKING:
- import pandas
- import pyarrow
- import torch
- from ray.data.block import DataBatch
- from ray.data.dataset import CollatedData, TorchDeviceType
# Type variable for the batch type a `CollateFn` accepts; bounded by `DataBatch`,
# which is only imported under TYPE_CHECKING (hence the string forward reference).
DataBatchType = TypeVar("DataBatchType", bound="DataBatch")
# A flat list or tuple of tensors.
TensorSequenceType = Union[
    List["torch.Tensor"],
    Tuple["torch.Tensor", ...],
]
# Every batch layout treated as a "tensor batch"; `is_tensor_batch_type` checks a
# runtime value against these variants.
TensorBatchType = Union[
    "torch.Tensor",
    TensorSequenceType,
    # For nested sequences of tensors, the inner sequence of tensors is combined during
    # GPU transfer in `move_tensors_to_device`.
    List[TensorSequenceType],
    Tuple[TensorSequenceType, ...],
    Mapping[str, "torch.Tensor"],
    # For mapping (e.g., dict) of keys to sequences of tensors, the sequence of tensors
    # is combined during GPU transfer in `move_tensors_to_device`.
    Mapping[str, TensorSequenceType],
]
def _is_tensor(batch: Any) -> bool:
    """Return True when ``batch`` is a single ``torch.Tensor`` instance.

    ``torch`` is imported inside the function rather than at module level.
    """
    from torch import Tensor

    return isinstance(batch, Tensor)
def _is_tensor_sequence(batch: Any) -> bool:
    """Return True when ``batch`` is a list or tuple made up only of tensors.

    An empty list or tuple also qualifies (there is no non-tensor element).

    >>> import torch
    >>> _is_tensor_sequence(torch.ones(1))
    False
    >>> _is_tensor_sequence([torch.ones(1), torch.ones(1)])
    True
    >>> _is_tensor_sequence((torch.ones(1), torch.ones(1)))
    True
    >>> _is_tensor_sequence([torch.ones(1), 1])
    False
    """
    if not isinstance(batch, (list, tuple)):
        return False
    return all(map(_is_tensor, batch))
def _is_nested_tensor_sequence(batch: Any) -> bool:
    """Return True when ``batch`` is a list/tuple whose items are tensor sequences.

    Only one level of nesting is inspected: every item must itself satisfy
    ``_is_tensor_sequence``.

    >>> import torch
    >>> _is_nested_tensor_sequence([torch.ones(1), torch.ones(1)])
    False
    >>> _is_nested_tensor_sequence(
    ...     ([torch.ones(1), torch.ones(1)], [torch.ones(1)])
    ... )
    True
    """
    if isinstance(batch, (list, tuple)):
        return all(_is_tensor_sequence(inner) for inner in batch)
    return False
def _is_tensor_mapping(batch: Any) -> bool:
    """Return True when ``batch`` is a ``Mapping`` whose values are all tensors.

    Keys are not inspected; an empty mapping also qualifies.

    >>> import torch
    >>> _is_tensor_mapping({"a": torch.ones(1), "b": torch.ones(1)})
    True
    >>> _is_tensor_mapping({"a": torch.ones(1), "b": [torch.ones(1), torch.ones(1)]})
    False
    """
    if not isinstance(batch, Mapping):
        return False
    return all(map(_is_tensor, batch.values()))
def _is_tensor_sequence_mapping(batch: Any) -> bool:
    """Return True when ``batch`` is a ``Mapping`` whose values are tensor sequences.

    Every value must satisfy ``_is_tensor_sequence`` (a list or tuple of tensors).

    >>> import torch
    >>> _is_tensor_sequence_mapping({"a": torch.ones(1), "b": torch.ones(1)})
    False
    >>> _is_tensor_sequence_mapping(
    ...     {"a": (torch.ones(1), torch.ones(1)), "b": [torch.ones(1), torch.ones(1)]}
    ... )
    True
    """
    if isinstance(batch, Mapping):
        return all(_is_tensor_sequence(value) for value in batch.values())
    return False
@DeveloperAPI
def is_tensor_batch_type(batch: Any) -> bool:
    """Report whether ``batch`` has one of the ``TensorBatchType`` layouts.

    Recognized layouts, tried in this order (the first match wins):

    1. A single torch.Tensor
    2. A sequence of torch.Tensors
    3. A sequence of sequences of torch.Tensors
    4. A mapping (e.g., dict) of keys to torch.Tensors
    5. A mapping (e.g., dict) of keys to sequences of torch.Tensors

    Args:
        batch: The value to inspect; any type is accepted.

    Returns:
        bool: True if ``batch`` matches at least one layout, False otherwise.
    """
    layout_checks = (
        _is_tensor,
        _is_tensor_sequence,
        _is_nested_tensor_sequence,
        _is_tensor_mapping,
        _is_tensor_sequence_mapping,
    )
    # `any` over a generator stops at the first matching predicate, preserving the
    # short-circuit order of a chained `or`.
    return any(check(batch) for check in layout_checks)
# Layouts a tensor batch is expected to take after processing: a single tensor, a
# tuple of tensors, or a dict of column name -> tensor. Unlike `TensorBatchType`,
# no nested sequences of tensors appear here.
TensorBatchReturnType = Union[
    "torch.Tensor",
    Tuple["torch.Tensor", ...],
    Dict[str, "torch.Tensor"],
]
@DeveloperAPI
class CollateFn(Generic[DataBatchType]):
    """Base interface for the ``collate_fn`` argument of `iter_torch_batches`.

    Subclasses are parameterized by the batch type they accept. See the doc-string
    of `collate_fn` in the `iter_torch_batches` API for more details.

    NOTE: ``__call__`` carries ``abc.abstractmethod``, but this class only derives
    from ``Generic`` (not ``abc.ABC``), so the marker is not enforced when the
    class or a subclass is instantiated.
    """

    @abc.abstractmethod
    def __call__(self, batch: DataBatchType) -> "CollatedData":
        """Turn one batch into the collated form the model consumes.

        Args:
            batch: The input batch to collate.

        Returns:
            The collated data in the format expected by the model.
        """
@DeveloperAPI
class ArrowBatchCollateFn(CollateFn["pyarrow.Table"]):
    """Collate function whose input batch is a ``pyarrow.Table``.

    Tables backed by chunked arrays can be moved to GPUs without first combining
    the chunks, by using the `arrow_batch_to_tensors` utility function. See
    `DefaultCollateFn` for an example.
    """

    def __call__(self, batch: "pyarrow.Table") -> "CollatedData":
        """Collate a ``pyarrow.Table`` batch.

        Args:
            batch: The input pyarrow.Table batch to collate.

        Returns:
            The collated data in the format expected by the model.
        """
@DeveloperAPI
class NumpyBatchCollateFn(CollateFn[Dict[str, np.ndarray]]):
    """Collate function whose input batch is a dict of column name -> numpy array."""

    def __call__(self, batch: Dict[str, np.ndarray]) -> "CollatedData":
        """Collate a batch given as a dictionary of numpy arrays.

        Args:
            batch: The input dictionary of numpy arrays batch to collate.

        Returns:
            The collated data in the format expected by the model.
        """
@DeveloperAPI
class PandasBatchCollateFn(CollateFn["pandas.DataFrame"]):
    """Collate function whose input batch is a ``pandas.DataFrame``."""

    def __call__(self, batch: "pandas.DataFrame") -> "CollatedData":
        """Collate a ``pandas.DataFrame`` batch.

        Args:
            batch: The input pandas.DataFrame batch to collate.

        Returns:
            The collated data in the format expected by the model.
        """
@DeveloperAPI
class DefaultCollateFn(ArrowBatchCollateFn):
    """Default collate function for converting Arrow batches to PyTorch tensors.

    Conversion is delegated to `arrow_batch_to_tensors`. When ``num_workers > 0``,
    a thread pool is created on the first call and handed to that utility.
    """

    # Default thread-pool size; can be overridden through the environment variable.
    _DEFAULT_NUM_WORKERS = env_integer(
        "RAY_DATA_DEFAULT_COLLATE_FN_THREADPOOL_MAX_WORKERS",
        4,
    )

    def __init__(
        self,
        dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
        device: Optional["TorchDeviceType"] = None,
        pin_memory: bool = False,
        num_workers: int = _DEFAULT_NUM_WORKERS,
    ):
        """Initialize the collate function.

        Args:
            dtypes: The torch dtype(s) for the created tensor(s); if None, the dtype
                will be inferred from the tensor data.
            device: The device on which the tensor should be placed. Can be a string
                (e.g. "cpu", "cuda:0"), an int (passed to ``torch.device``), or a
                torch.device object.
            pin_memory: Whether to pin the memory of the created tensors.
            num_workers: Number of worker threads for parallel tensor conversion.
                Values <= 0 disable the thread pool. Defaults to
                `RAY_DATA_DEFAULT_COLLATE_FN_THREADPOOL_MAX_WORKERS`.
        """
        import torch

        super().__init__()
        self.dtypes = dtypes
        # Normalize string/int device specs so `self.device.type` is available later.
        if isinstance(device, (str, int)):
            self.device = torch.device(device)
        else:
            self.device = device
        self.pin_memory = pin_memory
        self.num_workers = num_workers
        # Created lazily in `__call__`.
        self._threadpool: Optional[ThreadPoolExecutor] = None

    def __getstate__(self) -> Dict[str, Any]:
        """Return the picklable instance state, without the thread pool.

        A ``ThreadPoolExecutor`` holds threads and locks and cannot be pickled or
        deep-copied. Because the pool is only created on the first call, dropping it
        here lets an instance be serialized or copied even after it has been
        called. The copy builds its own pool the next time it is called.
        """
        state = self.__dict__.copy()
        state["_threadpool"] = None
        return state

    def __del__(self):
        """Clean up threadpool on destruction (without waiting for pending work)."""
        # `getattr` guards against `__init__` having failed before the attribute
        # was set.
        threadpool = getattr(self, "_threadpool", None)
        if threadpool is not None:
            threadpool.shutdown(wait=False)

    def __call__(
        self, batch: "pyarrow.Table"
    ) -> Union[Dict[str, "torch.Tensor"], Dict[str, List["torch.Tensor"]]]:
        """Convert an Arrow batch to PyTorch tensors.

        Args:
            batch: PyArrow Table to convert.

        Returns:
            Dictionary mapping column names to tensors. When chunks are not combined
            (any device other than CPU, including no device), a column may instead
            map to a list of tensors.
        """
        from ray.data.util.torch_utils import (
            arrow_batch_to_tensors,
        )

        if self.num_workers > 0 and self._threadpool is None:
            self._threadpool = ThreadPoolExecutor(max_workers=self.num_workers)

        # For GPU transfer, we can skip the combining chunked arrays. This is because
        # we can convert the chunked arrays to corresponding numpy format and then to
        # Tensors and transfer the corresponding list of Tensors to GPU directly.
        # However, for CPU transfer, we need to combine the chunked arrays first
        # before converting to numpy format and then to Tensors. When no device is
        # set, chunks are likewise left uncombined.
        combine_chunks = self.device is not None and self.device.type == "cpu"
        return arrow_batch_to_tensors(
            batch,
            dtypes=self.dtypes,
            combine_chunks=combine_chunks,
            pin_memory=self.pin_memory,
            threadpool=self._threadpool,
        )
|