import abc
from concurrent.futures import ThreadPoolExecutor
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Generic,
    List,
    Mapping,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import numpy as np

from ray._private.ray_constants import env_integer
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
    import pandas
    import pyarrow
    import torch

    from ray.data.block import DataBatch
    from ray.data.dataset import CollatedData, TorchDeviceType


DataBatchType = TypeVar("DataBatchType", bound="DataBatch")

TensorSequenceType = Union[
    List["torch.Tensor"],
    Tuple["torch.Tensor", ...],
]
TensorBatchType = Union[
    "torch.Tensor",
    TensorSequenceType,
    # For nested sequences of tensors, the inner sequence of tensors is combined during
    # GPU transfer in `move_tensors_to_device`.
    List[TensorSequenceType],
    Tuple[TensorSequenceType, ...],
    Mapping[str, "torch.Tensor"],
    # For mapping (e.g., dict) of keys to sequences of tensors, the sequence of tensors
    # is combined during GPU transfer in `move_tensors_to_device`.
    Mapping[str, TensorSequenceType],
]


def _is_tensor(batch: Any) -> bool:
    """Check if a batch is a single torch.Tensor."""
    # Imported lazily so torch is only required when this check actually runs.
    import torch

    return isinstance(batch, torch.Tensor)


def _is_tensor_sequence(batch: Any) -> bool:
    """Check if a batch is a sequence of torch.Tensors.

    Only ``list`` and ``tuple`` count as sequences. An empty list/tuple is
    reported as a tensor sequence, because ``all()`` over an empty iterable
    is True.

    >>> import torch
    >>> _is_tensor_sequence(torch.ones(1))
    False
    >>> _is_tensor_sequence([torch.ones(1), torch.ones(1)])
    True
    >>> _is_tensor_sequence((torch.ones(1), torch.ones(1)))
    True
    >>> _is_tensor_sequence([torch.ones(1), 1])
    False
    """
    return isinstance(batch, (list, tuple)) and all(_is_tensor(t) for t in batch)


def _is_nested_tensor_sequence(batch: Any) -> bool:
    """Check if a batch is a sequence of sequences of torch.Tensors.

    Stops at one level of nesting: every element of the outer list/tuple must
    itself satisfy `_is_tensor_sequence`. As with `_is_tensor_sequence`, empty
    outer or inner lists/tuples are accepted.

    >>> import torch
    >>> _is_nested_tensor_sequence([torch.ones(1), torch.ones(1)])
    False
    >>> _is_nested_tensor_sequence(
    ...     ([torch.ones(1), torch.ones(1)], [torch.ones(1)])
    ... )
    True
    """
    return isinstance(batch, (list, tuple)) and all(
        _is_tensor_sequence(t) for t in batch
    )


def _is_tensor_mapping(batch: Any) -> bool:
    """Check if a batch is a mapping of keys to torch.Tensors.

    Any ``Mapping`` instance qualifies when all of its values are tensors; an
    empty mapping therefore also qualifies.

    >>> import torch
    >>> _is_tensor_mapping({"a": torch.ones(1), "b": torch.ones(1)})
    True
    >>> _is_tensor_mapping({"a": torch.ones(1), "b": [torch.ones(1), torch.ones(1)]})
    False
    """
    return isinstance(batch, Mapping) and all(_is_tensor(v) for v in batch.values())


def _is_tensor_sequence_mapping(batch: Any) -> bool:
    """Check if a batch is a mapping of keys to sequences of torch.Tensors.

    Every value must satisfy `_is_tensor_sequence`; an empty mapping qualifies.

    >>> import torch
    >>> _is_tensor_sequence_mapping({"a": torch.ones(1), "b": torch.ones(1)})
    False
    >>> _is_tensor_sequence_mapping(
    ...     {"a": (torch.ones(1), torch.ones(1)), "b": [torch.ones(1), torch.ones(1)]}
    ... )
    True
    """
    return isinstance(batch, Mapping) and all(
        _is_tensor_sequence(v) for v in batch.values()
    )


@DeveloperAPI
def is_tensor_batch_type(batch: Any) -> bool:
    """Check if a batch matches any of the TensorBatchType variants.

    This function checks if the input batch is one of the following types:

    1. A single torch.Tensor
    2. A sequence of torch.Tensors
    3. A sequence of sequences of torch.Tensors
    4. A mapping (e.g., dict) of keys to torch.Tensors
    5. A mapping (e.g., dict) of keys to sequences of torch.Tensors

    Variants are checked in the order listed. Note that an empty list, tuple,
    or mapping matches (via variants 2 and 4 respectively).

    Args:
        batch: The input batch to check. Can be any type.

    Returns:
        bool: True if the batch matches any TensorBatchType variant, False otherwise.
    """
    return (
        _is_tensor(batch)
        or _is_tensor_sequence(batch)
        or _is_nested_tensor_sequence(batch)
        or _is_tensor_mapping(batch)
        or _is_tensor_sequence_mapping(batch)
    )


TensorBatchReturnType = Union[
    "torch.Tensor",
    Tuple["torch.Tensor", ...],
    Dict[str, "torch.Tensor"],
]


@DeveloperAPI
class CollateFn(abc.ABC, Generic[DataBatchType]):
    """Abstract interface for collate_fn for `iter_torch_batches`.

    See doc-string of `collate_fn` in `iter_torch_batches` API for more details.

    Subclasses must implement `__call__`. Inheriting from `abc.ABC` is what
    makes the `@abc.abstractmethod` decorator below take effect: without an
    `ABCMeta` metaclass the decorator is ignored and this base class could be
    instantiated directly.
    """

    @abc.abstractmethod
    def __call__(self, batch: DataBatchType) -> "CollatedData":
        """Convert a batch of data to collated format.

        Args:
            batch: The input batch to collate.

        Returns:
            The collated data in the format expected by the model.
        """
        ...


@DeveloperAPI
class ArrowBatchCollateFn(CollateFn["pyarrow.Table"]):
    """Collate function that takes pyarrow.Table as the input batch type.

    With the `arrow_batch_to_tensors` utility function, Arrow tables with
    chunked arrays can be transferred to GPUs efficiently without first
    combining the chunks. See `DefaultCollateFn` for an example.
    """

    def __call__(self, batch: "pyarrow.Table") -> "CollatedData":
        """Convert a batch of pyarrow.Table to collated format.

        Args:
            batch: The input pyarrow.Table batch to collate.

        Returns:
            The collated data in the format expected by the model.
        """
        ...


@DeveloperAPI
class NumpyBatchCollateFn(CollateFn[Dict[str, np.ndarray]]):
    """Collate function that takes a dictionary of numpy arrays as the input batch type."""

    def __call__(self, batch: Dict[str, np.ndarray]) -> "CollatedData":
        """Convert a batch of numpy arrays to collated format.

        Args:
            batch: The input dictionary of numpy arrays batch to collate.

        Returns:
            The collated data in the format expected by the model.
        """
        ...


@DeveloperAPI
class PandasBatchCollateFn(CollateFn["pandas.DataFrame"]):
    """Collate function that takes a pandas.DataFrame as the input batch type."""

    def __call__(self, batch: "pandas.DataFrame") -> "CollatedData":
        """Convert a batch of pandas.DataFrame to collated format.

        Args:
            batch: The input pandas.DataFrame batch to collate.

        Returns:
            The collated data in the format expected by the model.
        """
        ...
@DeveloperAPI
class DefaultCollateFn(ArrowBatchCollateFn):
    """Default collate function that turns Arrow batches into PyTorch tensors.

    The actual conversion is delegated to `arrow_batch_to_tensors`. When
    `num_workers` is positive, a thread pool is created lazily on the first
    call and reused for every subsequent call.
    """

    # Default size of the conversion thread pool; can be overridden through the
    # environment variable named below.
    _DEFAULT_NUM_WORKERS = env_integer(
        "RAY_DATA_DEFAULT_COLLATE_FN_THREADPOOL_MAX_WORKERS",
        4,
    )

    def __init__(
        self,
        dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
        device: Optional["TorchDeviceType"] = None,
        pin_memory: bool = False,
        num_workers: int = _DEFAULT_NUM_WORKERS,
    ):
        """Set up the collate function.

        Args:
            dtypes: The torch dtype(s) for the produced tensor(s): either one
                dtype or a ``Dict[str, torch.dtype]`` (presumably keyed by
                column name). If None, dtypes are inferred from the data.
            device: Target device. Strings (e.g. "cpu", "cuda:0") and ints are
                converted to a ``torch.device``; any other value (including a
                ``torch.device`` or None) is stored unchanged.
            pin_memory: Whether to pin the memory of the created tensors.
            num_workers: Number of worker threads used for tensor conversion.
                When it is not positive, no thread pool is created. Defaults
                to `RAY_DATA_DEFAULT_COLLATE_FN_THREADPOOL_MAX_WORKERS`.
        """
        import torch

        super().__init__()
        self.dtypes = dtypes
        # Normalize str/int device specs so `__call__` can inspect `.type`.
        self.device = (
            torch.device(device) if isinstance(device, (str, int)) else device
        )
        self.pin_memory = pin_memory
        self.num_workers = num_workers
        # Created on demand by `__call__`; stays None when num_workers <= 0.
        self._threadpool: Optional[ThreadPoolExecutor] = None

    def __del__(self):
        """Shut down the thread pool, if one was created, without waiting."""
        # getattr guards against a partially-initialized instance (e.g. when
        # __init__ raised before `_threadpool` was assigned).
        pool = getattr(self, "_threadpool", None)
        if pool:
            pool.shutdown(wait=False)

    def __call__(
        self, batch: "pyarrow.Table"
    ) -> Union[Dict[str, "torch.Tensor"], Dict[str, List["torch.Tensor"]]]:
        """Convert an Arrow batch to PyTorch tensors.

        Args:
            batch: PyArrow Table to convert.

        Returns:
            Dictionary mapping column names to tensors. As the return
            annotation states, the values are either single tensors or lists
            of tensors; presumably lists when chunks are left uncombined
            (device unset or not "cpu") — TODO confirm against
            `arrow_batch_to_tensors`.
        """
        from ray.data.util.torch_utils import arrow_batch_to_tensors

        if self._threadpool is None and self.num_workers > 0:
            self._threadpool = ThreadPoolExecutor(max_workers=self.num_workers)

        # GPU-bound batches skip combining chunked arrays: each chunk can be
        # converted to numpy, then to a tensor, and the resulting list of
        # tensors moved to the GPU directly. A CPU target instead needs the
        # chunks combined before the numpy/tensor conversion. Note that an
        # unset device (None) also leaves the chunks uncombined.
        target_is_cpu = self.device is not None and self.device.type == "cpu"

        return arrow_batch_to_tensors(
            batch,
            dtypes=self.dtypes,
            combine_chunks=target_is_cpu,
            pin_memory=self.pin_memory,
            threadpool=self._threadpool,
        )