iterator.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067
  1. import abc
  2. import time
  3. import warnings
  4. from typing import (
  5. TYPE_CHECKING,
  6. Any,
  7. Callable,
  8. Dict,
  9. Iterable,
  10. Iterator,
  11. List,
  12. Literal,
  13. Optional,
  14. Tuple,
  15. TypeVar,
  16. Union,
  17. )
  18. import numpy as np
  19. from ray.data._internal.block_batching.iter_batches import BatchIterator
  20. from ray.data._internal.execution.interfaces import RefBundle
  21. from ray.data._internal.logical.interfaces import LogicalPlan
  22. from ray.data._internal.logical.operators import InputData
  23. from ray.data._internal.plan import ExecutionPlan
  24. from ray.data._internal.stats import DatasetStats
  25. from ray.data.block import BlockAccessor, DataBatch, _apply_batch_format
  26. from ray.data.collate_fn import (
  27. ArrowBatchCollateFn,
  28. CollateFn,
  29. DefaultCollateFn,
  30. NumpyBatchCollateFn,
  31. PandasBatchCollateFn,
  32. TensorBatchReturnType,
  33. TensorBatchType,
  34. is_tensor_batch_type,
  35. )
  36. from ray.data.context import DataContext
  37. from ray.util.annotations import PublicAPI, RayDeprecationWarning
  38. if TYPE_CHECKING:
  39. import tensorflow as tf
  40. import torch
  41. from ray.data._internal.execution.streaming_executor import StreamingExecutor
  42. from ray.data.dataset import (
  43. CollatedData,
  44. MaterializedDataset,
  45. Schema,
  46. TensorFlowTensorBatchType,
  47. TorchBatchType,
  48. TorchDeviceType,
  49. )
  50. T = TypeVar("T")
  51. class _IterableFromIterator(Iterable[T]):
  52. def __init__(self, iterator_gen: Callable[[], Iterator[T]]):
  53. """Constructs an Iterable from an iterator generator.
  54. Args:
  55. iterator_gen: A function that returns an iterator each time it
  56. is called. For example, this can be a generator function.
  57. """
  58. self.iterator_gen = iterator_gen
  59. def __iter__(self):
  60. return self.iterator_gen()
  61. @PublicAPI
  62. class DataIterator(abc.ABC):
  63. """An iterator for reading records from a :class:`~Dataset`.
  64. For Datasets, each iteration call represents a complete read of all items in the
  65. Dataset.
  66. If using Ray Train, each trainer actor should get its own iterator by calling
  67. :meth:`ray.train.get_dataset_shard("train")
  68. <ray.train.get_dataset_shard>`.
  69. Examples:
  70. >>> import ray
  71. >>> ds = ray.data.range(5)
  72. >>> ds
  73. shape: (5, 1)
  74. ╭───────╮
  75. │ id │
  76. │ --- │
  77. │ int64 │
  78. ╰───────╯
  79. (Dataset isn't materialized)
  80. >>> ds.iterator()
  81. DataIterator(shape: (5, 1)
  82. ╭───────╮
  83. │ id │
  84. │ --- │
  85. │ int64 │
  86. ╰───────╯
  87. (Dataset isn't materialized))
  88. """
  @abc.abstractmethod
  def _to_ref_bundle_iterator(
      self,
  ) -> Tuple[
      Iterator[RefBundle],
      Optional[DatasetStats],
      bool,
      Optional["StreamingExecutor"],
  ]:
      """Produce the RefBundle source that backs batch iteration.

      ``_iter_batches`` calls this once at the start of every pass over the
      data, so each call should begin reading from the start.

      Returns:
          A 4-tuple ``(bundles, stats, blocks_owned_by_consumer, executor)``:

          - ``bundles``: an iterator of ``RefBundle`` objects to consume.
          - ``stats``: a ``DatasetStats`` used to record iteration timings,
            or ``None``.
          - ``blocks_owned_by_consumer``: ``True`` when blocks may be cleared
            safely once they have been read.
          - ``executor``: an optional ``StreamingExecutor`` that is told how
            many bytes have been prefetched.
      """
      ...
  @PublicAPI
  def iter_batches(
      self,
      *,
      prefetch_batches: int = 1,
      batch_size: Optional[int] = 256,
      batch_format: Optional[str] = "default",
      drop_last: bool = False,
      local_shuffle_buffer_size: Optional[int] = None,
      local_shuffle_seed: Optional[int] = None,
  ) -> Iterable[DataBatch]:
      """Return a batched iterable over the dataset.

      Examples:
          >>> import ray
          >>> for batch in ray.data.range(
          ...     1000000
          ... ).iterator().iter_batches(): # doctest: +SKIP
          ...     print(batch) # doctest: +SKIP

      Time complexity: O(1)

      Args:
          prefetch_batches: The number of batches to fetch ahead of the current
              batch. If set to greater than 0, a separate threadpool will be used
              to fetch the objects to the local node and format the batches.
              Defaults to 1.
          batch_size: The number of rows in each batch, or None to use entire blocks
              as batches (blocks may contain different number of rows).
              The final batch may include fewer than ``batch_size`` rows if
              ``drop_last`` is ``False``. Defaults to 256.
          batch_format: Specify ``"default"`` to use the default block format
              (NumPy), ``"pandas"`` to select ``pandas.DataFrame``, ``"pyarrow"``
              to select ``pyarrow.Table``, or ``"numpy"`` to select
              ``Dict[str, numpy.ndarray]``, or None to return the underlying block
              exactly as is with no additional formatting.
          drop_last: Whether to drop the last batch if it's incomplete.
          local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
              using a local in-memory shuffle buffer, and this value will serve as the
              minimum number of rows that must be in the local in-memory shuffle
              buffer in order to yield a batch. When there are no more rows to add to
              the buffer, the remaining rows in the buffer will be drained.
          local_shuffle_seed: The seed to use for the local random shuffle.

      Returns:
          An iterable over record batches. Each ``iter()`` call on it starts a
          new pass over the dataset from the beginning.
      """
      return self._iter_batches(
          prefetch_batches=prefetch_batches,
          batch_size=batch_size,
          batch_format=batch_format,
          drop_last=drop_last,
          local_shuffle_buffer_size=local_shuffle_buffer_size,
          local_shuffle_seed=local_shuffle_seed,
      )
  def _create_batch_iterator(
      self,
      ref_bundles_iter: Iterator[RefBundle],
      prefetch_bytes_callback: Optional[Callable[[int], None]] = None,
      **kwargs,
  ) -> BatchIterator:
      """Build the ``BatchIterator`` that turns RefBundles into batches.

      Args:
          ref_bundles_iter: Iterator of ``RefBundle`` objects to read from.
          prefetch_bytes_callback: Optional callable invoked with a byte count;
              forwarded unchanged to ``BatchIterator``.
          **kwargs: Any further keyword options, forwarded unchanged to the
              ``BatchIterator`` constructor.

      Returns:
          The constructed ``BatchIterator``.
      """
      return BatchIterator(
          ref_bundles_iter,
          **kwargs,
          prefetch_bytes_callback=prefetch_bytes_callback,
      )
  def _iter_batches(
      self,
      *,
      prefetch_batches: int = 1,
      batch_size: Optional[int] = 256,
      batch_format: Optional[str] = "default",
      drop_last: bool = False,
      local_shuffle_buffer_size: Optional[int] = None,
      local_shuffle_seed: Optional[int] = None,
      _collate_fn: Optional[Callable[[DataBatch], "CollatedData"]] = None,
      _finalize_fn: Optional[Callable[[Any], Any]] = None,
  ) -> Iterable[DataBatch]:
      """Shared implementation behind the public batch, row and tensor iterators.

      Args:
          prefetch_batches: Number of batches to prefetch; forwarded to the batch
              iterator.
          batch_size: Rows per batch, or ``None`` to use whole blocks as batches.
          batch_format: Requested batch format; normalized once through
              ``_apply_batch_format`` before any iteration starts.
          drop_last: Whether to drop a trailing incomplete batch.
          local_shuffle_buffer_size: Minimum number of rows buffered for local
              shuffling, or ``None`` to disable local shuffling.
          local_shuffle_seed: Seed for the local shuffle.
          _collate_fn: Optional per-batch function, forwarded to the batch
              iterator as ``collate_fn``.
          _finalize_fn: Optional per-batch function, forwarded to the batch
              iterator as ``finalize_fn``.

      Returns:
          A re-iterable object. Every ``iter()`` call on it calls
          ``_to_ref_bundle_iterator`` again and builds a fresh batch iterator,
          so each pass starts from the beginning of the dataset.
      """
      batch_format = _apply_batch_format(batch_format)

      def _create_iterator() -> Iterator[DataBatch]:
          time_start = time.perf_counter()
          # Iterate through the dataset from the start each time
          # _create_iterator is called. This allows multiple iterations of the
          # dataset without needing to explicitly call `iter_batches()` again.
          (
              ref_bundles_iterator,
              stats,
              blocks_owned_by_consumer,
              executor,
          ) = self._to_ref_bundle_iterator()

          dataset_tag = self._get_dataset_tag()

          # Create a callback to report prefetched bytes to the executor's
          # resource manager. The parameter is not named ``exec`` to avoid
          # shadowing the builtin.
          def make_prefetch_callback(streaming_executor):
              def callback(num_bytes: int) -> None:
                  streaming_executor.set_external_consumer_bytes(num_bytes)

              return callback

          prefetch_bytes_callback = (
              make_prefetch_callback(executor) if executor is not None else None
          )

          batch_iterator = self._create_batch_iterator(
              ref_bundles_iterator,
              stats=stats,
              dataset_tag=dataset_tag,
              clear_block_after_read=blocks_owned_by_consumer,
              batch_size=batch_size,
              batch_format=batch_format,
              drop_last=drop_last,
              collate_fn=_collate_fn,
              finalize_fn=_finalize_fn,
              shuffle_buffer_min_size=local_shuffle_buffer_size,
              shuffle_seed=local_shuffle_seed,
              prefetch_batches=prefetch_batches,
              prefetch_bytes_callback=prefetch_bytes_callback,
          )

          if stats:
              stats.iter_initialize_s.add(time.perf_counter() - time_start)

          yield from batch_iterator

          # NOTE: this line is reached only when the consumer exhausts the
          # iterator. If the consumer stops early (e.g. ``break``), the generator
          # is closed before this point and total time is not recorded.
          if stats:
              stats.iter_total_s.add(time.perf_counter() - time_start)

      return _IterableFromIterator(_create_iterator)
  222. def _get_dataset_tag(self) -> str:
  223. return "unknown_dataset"
  @PublicAPI
  def iter_rows(self) -> Iterable[Dict[str, Any]]:
      """Return a local row iterable over the dataset.

      Whole blocks are read unformatted (``batch_size=None``,
      ``batch_format=None``) and then walked row by row. If the dataset is a
      tabular dataset (Arrow/Pandas blocks), dicts are yielded for each row. If
      the dataset is not tabular, the raw row is yielded.

      Examples:
          >>> import ray
          >>> dataset = ray.data.range(10)
          >>> next(iter(dataset.iterator().iter_rows()))
          {'id': 0}

      Time complexity: O(1)

      Returns:
          An iterable over rows of the dataset; each ``iter()`` call starts a
          new pass.
      """
      block_batches = self._iter_batches(
          prefetch_batches=1, batch_size=None, batch_format=None
      )

      def _row_generator():
          for raw_batch in block_batches:
              block = BlockAccessor.batch_to_block(raw_batch)
              accessor = BlockAccessor.for_block(block)
              for row in accessor.iter_rows(public_row_format=True):
                  yield row

      return _IterableFromIterator(_row_generator)
  @abc.abstractmethod
  @PublicAPI
  def stats(self) -> str:
      """Report execution timing for this iterator.

      Returns:
          A string containing execution timing information.
      """
      ...
  253. @abc.abstractmethod
  254. def schema(self) -> "Schema":
  255. """Return the schema of the dataset iterated over."""
  256. ...
  @abc.abstractmethod
  def get_context(self) -> DataContext:
      """Return the ``DataContext`` associated with this iterator.

      Concrete subclasses decide which context instance is returned.
      """
      ...
  @PublicAPI
  def iter_torch_batches(
      self,
      *,
      prefetch_batches: int = 1,
      batch_size: Optional[int] = 256,
      dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
      device: Union["TorchDeviceType", Literal["auto"]] = "auto",
      collate_fn: Optional[
          Union[Callable[[Dict[str, np.ndarray]], "CollatedData"], CollateFn]
      ] = None,
      drop_last: bool = False,
      local_shuffle_buffer_size: Optional[int] = None,
      local_shuffle_seed: Optional[int] = None,
      pin_memory: bool = False,
  ) -> Iterable["TorchBatchType"]:
      """Return a batched iterable of Torch Tensors over the dataset.

      This iterable yields a dictionary of column-tensors. If you are looking for
      more flexibility in the tensor conversion (e.g. casting dtypes) or the batch
      format, try using :meth:`~ray.data.DataIterator.iter_batches` directly.

      Examples:
          >>> import ray
          >>> for batch in ray.data.range(
          ...     12,
          ... ).iterator().iter_torch_batches(batch_size=4):
          ...     print(batch)
          {'id': tensor([0, 1, 2, 3])}
          {'id': tensor([4, 5, 6, 7])}
          {'id': tensor([ 8,  9, 10, 11])}

          Use the ``ArrowBatchCollateFn`` to customize how the tensor batch is
          created from an Arrow batch.

          >>> import pyarrow as pa
          >>> import torch
          >>> import ray
          >>> from ray.data.collate_fn import ArrowBatchCollateFn
          >>> class CustomArrowBatchCollateFn(ArrowBatchCollateFn):
          ...     def __call__(self, batch: pa.Table) -> torch.Tensor:
          ...         return torch.as_tensor(batch["col_1"].to_numpy() + 5)
          >>> iterator = ray.data.from_items([
          ...     {"col_1": 1, "col_2": 2},
          ...     {"col_1": 3, "col_2": 4}]).iterator()
          >>> for batch in iterator.iter_torch_batches(collate_fn=CustomArrowBatchCollateFn()):
          ...     print(batch)
          tensor([6, 8])

          Use the ``NumpyBatchCollateFn`` to customize how the tensor batch is
          created from a Numpy batch.

          >>> from typing import Dict
          >>> import numpy as np
          >>> import torch
          >>> import ray
          >>> from ray.data.collate_fn import NumpyBatchCollateFn
          >>> class CustomNumpyBatchCollateFn(NumpyBatchCollateFn):
          ...     def __call__(self, batch: Dict[str, np.ndarray]) -> torch.Tensor:
          ...         return torch.as_tensor(batch["col_1"] + 5)
          >>> iterator = ray.data.from_items([
          ...     {"col_1": 1, "col_2": 2},
          ...     {"col_1": 3, "col_2": 4}]).iterator()
          >>> for batch in iterator.iter_torch_batches(collate_fn=CustomNumpyBatchCollateFn()):
          ...     print(batch)
          tensor([6, 8])

          Use the ``PandasBatchCollateFn`` to customize how the tensor batch is
          created from a Pandas batch.

          >>> import pandas as pd
          >>> import torch
          >>> import ray
          >>> from ray.data.collate_fn import PandasBatchCollateFn
          >>> class CustomPandasBatchCollateFn(PandasBatchCollateFn):
          ...     def __call__(self, batch: pd.DataFrame) -> torch.Tensor:
          ...         return torch.as_tensor(batch["col_1"].to_numpy() + 5)
          >>> iterator = ray.data.from_items([
          ...     {"col_1": 1, "col_2": 2},
          ...     {"col_1": 3, "col_2": 4}]).iterator()
          >>> for batch in iterator.iter_torch_batches(collate_fn=CustomPandasBatchCollateFn()):
          ...     print(batch)
          tensor([6, 8])

      Time complexity: O(1)

      Args:
          prefetch_batches: The number of batches to fetch ahead of the current
              batch. If set to greater than 0, a separate threadpool will be used
              to fetch the objects to the local node, format the batches, and apply
              the collate_fn. Defaults to 1.
          batch_size: The number of rows in each batch, or None to use entire blocks
              as batches (blocks may contain different number of rows).
              The final batch may include fewer than ``batch_size`` rows if
              ``drop_last`` is ``False``. Defaults to 256.
          dtypes: The Torch dtype(s) for the created tensor(s); if None, the dtype
              will be inferred from the tensor data. You can't use this parameter
              with ``collate_fn``.
          device: The device on which the tensor should be placed. Defaults to
              "auto" which moves the tensors to the appropriate device when the
              Dataset is passed to Ray Train and ``collate_fn`` is not provided.
              Otherwise, defaults to CPU. You can't use this parameter with
              ``collate_fn``.
          collate_fn: [Alpha] A function to customize how data batches are collated
              before being passed to the model. This is useful for last-mile data
              formatting such as padding, masking, or packaging tensors into custom
              data structures. If not provided, `iter_torch_batches` automatically
              converts batches to `torch.Tensor`s and moves them to the device
              assigned to the current worker. The input to `collate_fn` may be:

              1. pyarrow.Table, where you should provide a callable class that
                 subclasses `ArrowBatchCollateFn` (recommended for best
                 performance). Note that you should use util function
                 `arrow_batch_to_tensors` to convert the pyarrow.Table to a
                 dictionary of non-contiguous tensor batches.
              2. Dict[str, np.ndarray], where you should provide a callable class
                 that subclasses `NumpyBatchCollateFn`.
              3. pd.DataFrame, where you should provide a callable class that
                 subclasses `PandasBatchCollateFn`.

              A plain function is still accepted (it receives NumPy batches) but
              emits a ``RayDeprecationWarning``.

              The output can be any type. If the output is a `TensorBatchType`, it
              will be automatically moved to the current worker's device. For other
              types, you must handle device transfer manually in your training loop.

              Note: This function is called in a multi-threaded context; avoid
              using thread-unsafe code.
          drop_last: Whether to drop the last batch if it's incomplete.
          local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
              using a local in-memory shuffle buffer, and this value will serve as the
              minimum number of rows that must be in the local in-memory shuffle
              buffer in order to yield a batch. When there are no more rows to add to
              the buffer, the remaining rows in the buffer will be drained. This
              buffer size must be greater than or equal to ``batch_size``, and
              therefore ``batch_size`` must also be specified when using local
              shuffling.
          local_shuffle_seed: The seed to use for the local random shuffle.
          pin_memory: [Alpha] If True, copies the tensor to pinned memory. Note that
              `pin_memory` is only supported when using `DefaultCollateFn`.

      Returns:
          An iterable over Torch Tensor batches.

      Raises:
          ValueError: If ``collate_fn`` is combined with ``dtypes`` or a non-"auto"
              ``device``; if ``pin_memory`` is combined with ``collate_fn``; or if
              ``collate_fn`` is neither a supported collate class nor callable.
      """
      from ray.train.torch import get_device
      from ray.train.utils import _in_ray_train_worker

      if collate_fn is not None and (dtypes is not None or device != "auto"):
          # Each literal ends with a space so the concatenated message reads
          # correctly.
          raise ValueError(
              "collate_fn cannot be used with dtypes and device. "
              "You should manually move the output Torch tensors to the "
              "desired dtype and device outside of collate_fn."
          )

      if pin_memory and collate_fn is not None:
          raise ValueError(
              "pin_memory is only supported when using `DefaultCollateFn`."
          )

      if device == "auto":
          # Use the appropriate device for Ray Train, or fall back to CPU if
          # Ray Train is not being used.
          device = get_device() if _in_ray_train_worker() else "cpu"

      from ray.data.util.torch_utils import (
          move_tensors_to_device,
      )

      # The default finalize_fn handles the host to device data transfer.
      # This is executed in a 1-thread pool separately from collate_fn
      # to allow independent parallelism of these steps.
      def default_finalize_fn(
          batch: TensorBatchType,
      ) -> Union[TensorBatchReturnType, Any]:
          """Default finalize function for moving PyTorch tensors to device.

          If batch is of type `TensorBatchType`, it will be automatically moved
          to the current worker's device. For other types, you must handle
          device transfer manually in your training loop.

          Args:
              batch: Input batch to move to device.

          Returns:
              Batch with tensors moved to the target device.
              - If input is TensorBatchType, returns tensors moved to device.
              - Otherwise returns the input unchanged.
          """
          if is_tensor_batch_type(batch):
              return move_tensors_to_device(batch, device=device)
          else:
              return batch

      if collate_fn is None:
          # The default collate_fn handles formatting and Tensor creation.
          # Here, we defer host to device data transfer to the subsequent
          # finalize_fn.
          collate_fn = DefaultCollateFn(
              dtypes=dtypes,
              device=device,
              pin_memory=pin_memory,
          )
          batch_format = "pyarrow"
      elif isinstance(collate_fn, ArrowBatchCollateFn):
          # The ArrowBatchCollateFn handles formatting and Tensor creation.
          # Here, we defer host to device data transfer to the subsequent
          # finalize_fn.
          batch_format = "pyarrow"
      elif isinstance(collate_fn, NumpyBatchCollateFn):
          batch_format = "numpy"
      elif isinstance(collate_fn, PandasBatchCollateFn):
          batch_format = "pandas"
      elif callable(collate_fn):
          # Legacy path: plain functions receive NumPy batches.
          batch_format = "numpy"
          warnings.warn(
              "Passing a function to `iter_torch_batches(collate_fn)` is "
              "deprecated in Ray 2.47. Please switch to using a callable class that "
              "inherits from `ArrowBatchCollateFn`, `NumpyBatchCollateFn`, or "
              "`PandasBatchCollateFn`.",
              RayDeprecationWarning,
          )
      else:
          raise ValueError(f"Unsupported collate function: {type(collate_fn)}")

      return self._iter_batches(
          prefetch_batches=prefetch_batches,
          batch_size=batch_size,
          batch_format=batch_format,
          drop_last=drop_last,
          local_shuffle_buffer_size=local_shuffle_buffer_size,
          local_shuffle_seed=local_shuffle_seed,
          _collate_fn=collate_fn,
          _finalize_fn=default_finalize_fn,
      )
  def iter_tf_batches(
      self,
      *,
      prefetch_batches: int = 1,
      batch_size: Optional[int] = 256,
      dtypes: Optional[Union["tf.dtypes.DType", Dict[str, "tf.dtypes.DType"]]] = None,
      drop_last: bool = False,
      local_shuffle_buffer_size: Optional[int] = None,
      local_shuffle_seed: Optional[int] = None,
  ) -> Iterable["TensorFlowTensorBatchType"]:
      """Return a batched iterable of TensorFlow Tensors over the dataset.

      This iterable will yield single-tensor batches if the underlying dataset
      consists of a single column; otherwise, it will yield a dictionary of
      column-tensors.

      .. tip::
          If you don't need the additional flexibility provided by this method,
          consider using :meth:`~ray.data.Dataset.to_tf` instead. It's easier
          to use.

      Examples:
          >>> import ray
          >>> for batch in ray.data.range( # doctest: +SKIP
          ...     12,
          ... ).iterator().iter_tf_batches(batch_size=4):
          ...     print(batch.shape) # doctest: +SKIP
          (4, 1)
          (4, 1)
          (4, 1)

      Time complexity: O(1)

      Args:
          prefetch_batches: The number of batches to fetch ahead of the current
              batch. If set to greater than 0, a separate threadpool will be used
              to fetch the objects to the local node and format the batches.
              Defaults to 1.
          batch_size: The number of rows in each batch, or None to use entire blocks
              as batches (blocks may contain different number of rows).
              The final batch may include fewer than ``batch_size`` rows if
              ``drop_last`` is ``False``. Defaults to 256.
          dtypes: The TensorFlow dtype(s) for the created tensor(s); if None, the
              dtype will be inferred from the tensor data.
          drop_last: Whether to drop the last batch if it's incomplete.
          local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
              using a local in-memory shuffle buffer, and this value will serve as the
              minimum number of rows that must be in the local in-memory shuffle
              buffer in order to yield a batch. When there are no more rows to add to
              the buffer, the remaining rows in the buffer will be drained. This
              buffer size must be greater than or equal to ``batch_size``, and
              therefore ``batch_size`` must also be specified when using local
              shuffling.
          local_shuffle_seed: The seed to use for the local random shuffle.

      Returns:
          A single-pass iterator over TensorFlow Tensor batches. Unlike
          :meth:`iter_batches`, the returned object cannot be iterated a second
          time; call this method again to start a new pass.
      """
      from ray.data._internal.utils.tensorflow_utils import (
          convert_ndarray_batch_to_tf_tensor_batch,
      )

      batch_iterable = self._iter_batches(
          prefetch_batches=prefetch_batches,
          batch_size=batch_size,
          drop_last=drop_last,
          local_shuffle_buffer_size=local_shuffle_buffer_size,
          local_shuffle_seed=local_shuffle_seed,
      )
      # Lazily convert each default-format (NumPy) batch as it is consumed.
      return (
          convert_ndarray_batch_to_tf_tensor_batch(batch, dtypes=dtypes)
          for batch in batch_iterable
      )
  538. def to_torch(
  539. self,
  540. *,
  541. label_column: Optional[str] = None,
  542. feature_columns: Optional[
  543. Union[List[str], List[List[str]], Dict[str, List[str]]]
  544. ] = None,
  545. label_column_dtype: Optional["torch.dtype"] = None,
  546. feature_column_dtypes: Optional[
  547. Union["torch.dtype", List["torch.dtype"], Dict[str, "torch.dtype"]]
  548. ] = None,
  549. batch_size: int = 1,
  550. prefetch_batches: int = 1,
  551. drop_last: bool = False,
  552. local_shuffle_buffer_size: Optional[int] = None,
  553. local_shuffle_seed: Optional[int] = None,
  554. unsqueeze_label_tensor: bool = True,
  555. unsqueeze_feature_tensors: bool = True,
  556. ) -> "torch.utils.data.IterableDataset":
  557. """Return a Torch IterableDataset over this dataset.
  558. This is only supported for datasets convertible to Arrow records.
  559. It is recommended to use the returned ``IterableDataset`` directly
  560. instead of passing it into a torch ``DataLoader``.
  561. Each element in IterableDataset will be a tuple consisting of 2
  562. elements. The first item contains the feature tensor(s), and the
  563. second item is the label tensor. Those can take on different
  564. forms, depending on the specified arguments.
  565. For the features tensor (N is the ``batch_size`` and n, m, k
  566. are the number of features per tensor):
  567. * If ``feature_columns`` is a ``List[str]``, the features will be
  568. a tensor of shape (N, n), with columns corresponding to
  569. ``feature_columns``
  570. * If ``feature_columns`` is a ``List[List[str]]``, the features will be
  571. a list of tensors of shape [(N, m),...,(N, k)], with columns of each
  572. tensor corresponding to the elements of ``feature_columns``
  573. * If ``feature_columns`` is a ``Dict[str, List[str]]``, the features
  574. will be a dict of key-tensor pairs of shape
  575. {key1: (N, m),..., keyN: (N, k)}, with columns of each
  576. tensor corresponding to the value of ``feature_columns`` under the
  577. key.
  578. If ``unsqueeze_label_tensor=True`` (default), the label tensor will be
  579. of shape (N, 1). Otherwise, it will be of shape (N,).
  580. If ``label_column`` is specified as ``None``, then no column from the
  581. ``Dataset`` will be treated as the label, and the output label tensor
  582. will be ``None``.
  583. Note that you probably want to call ``.split()`` on this dataset if
  584. there are to be multiple Torch workers consuming the data.
  585. Time complexity: O(1)
  586. Args:
  587. label_column: The name of the column used as the
  588. label (second element of the output list). Can be None for
  589. prediction, in which case the second element of returned
  590. tuple will also be None.
  591. feature_columns: The names of the columns
  592. to use as the features. Can be a list of lists or
  593. a dict of string-list pairs for multi-tensor output.
  594. If None, then use all columns except the label column as
  595. the features.
  596. label_column_dtype: The torch dtype to
  597. use for the label column. If None, then automatically infer
  598. the dtype.
  599. feature_column_dtypes: The dtypes to use for the feature
  600. tensors. This should match the format of ``feature_columns``,
  601. or be a single dtype, in which case it will be applied to
  602. all tensors. If None, then automatically infer the dtype.
  603. batch_size: How many samples per batch to yield at a time.
  604. Defaults to 1.
  605. prefetch_batches: The number of batches to fetch ahead of the current batch
  606. to fetch. If set to greater than 0, a separate threadpool will be used
  607. to fetch the objects to the local node, format the batches, and apply
  608. the collate_fn. Defaults to 1.
  609. drop_last: Set to True to drop the last incomplete batch,
  610. if the dataset size is not divisible by the batch size. If
  611. False and the size of dataset is not divisible by the batch
  612. size, then the last batch will be smaller. Defaults to False.
  613. local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
  614. using a local in-memory shuffle buffer, and this value will serve as the
  615. minimum number of rows that must be in the local in-memory shuffle
  616. buffer in order to yield a batch. When there are no more rows to add to
  617. the buffer, the remaining rows in the buffer will be drained. This
  618. buffer size must be greater than or equal to ``batch_size``, and
  619. therefore ``batch_size`` must also be specified when using local
  620. shuffling.
  621. local_shuffle_seed: The seed to use for the local random shuffle.
  622. unsqueeze_label_tensor: If set to True, the label tensor
  623. will be unsqueezed (reshaped to (N, 1)). Otherwise, it will
  624. be left as is, that is (N, ). In general, regression loss
  625. functions expect an unsqueezed tensor, while classification
  626. loss functions expect a squeezed one. Defaults to True.
  627. unsqueeze_feature_tensors: If set to True, the features tensors
  628. will be unsqueezed (reshaped to (N, 1)) before being concatenated into
  629. the final features tensor. Otherwise, they will be left as is, that is
  630. (N, ). Defaults to True.
  631. Returns:
  632. A torch IterableDataset.
  633. """
  634. import torch
  635. from ray.data._internal.torch_iterable_dataset import TorchIterableDataset
  636. from ray.data.util.torch_utils import convert_pandas_to_torch_tensor
  637. # If an empty collection is passed in, treat it the same as None
  638. if not feature_columns:
  639. feature_columns = None
  640. if feature_column_dtypes and not isinstance(feature_column_dtypes, torch.dtype):
  641. if isinstance(feature_columns, dict):
  642. if not isinstance(feature_column_dtypes, dict):
  643. raise TypeError(
  644. "If `feature_columns` is a dict, "
  645. "`feature_column_dtypes` must be None, `torch.dtype`,"
  646. f" or dict, got {type(feature_column_dtypes)}."
  647. )
  648. if set(feature_columns) != set(feature_column_dtypes):
  649. raise ValueError(
  650. "`feature_columns` and `feature_column_dtypes` "
  651. "must have the same keys."
  652. )
  653. if any(not subcolumns for subcolumns in feature_columns.values()):
  654. raise ValueError("column list may not be empty")
  655. elif isinstance(feature_columns[0], (list, tuple)):
  656. if not isinstance(feature_column_dtypes, (list, tuple)):
  657. raise TypeError(
  658. "If `feature_columns` is a list of lists, "
  659. "`feature_column_dtypes` must be None, `torch.dtype`,"
  660. f" or a sequence, got {type(feature_column_dtypes)}."
  661. )
  662. if len(feature_columns) != len(feature_column_dtypes):
  663. raise ValueError(
  664. "`feature_columns` and `feature_column_dtypes` "
  665. "must have the same length."
  666. )
  667. if any(not subcolumns for subcolumns in feature_columns):
  668. raise ValueError("column list may not be empty")
  669. def make_generator():
  670. for batch in self._iter_batches(
  671. batch_size=batch_size,
  672. batch_format="pandas",
  673. prefetch_batches=prefetch_batches,
  674. drop_last=drop_last,
  675. local_shuffle_buffer_size=local_shuffle_buffer_size,
  676. local_shuffle_seed=local_shuffle_seed,
  677. ):
  678. if label_column:
  679. label_tensor = convert_pandas_to_torch_tensor(
  680. batch,
  681. [label_column],
  682. label_column_dtype,
  683. unsqueeze=unsqueeze_label_tensor,
  684. )
  685. batch.pop(label_column)
  686. else:
  687. label_tensor = None
  688. if isinstance(feature_columns, dict):
  689. features_tensor = {
  690. key: convert_pandas_to_torch_tensor(
  691. batch,
  692. feature_columns[key],
  693. (
  694. feature_column_dtypes[key]
  695. if isinstance(feature_column_dtypes, dict)
  696. else feature_column_dtypes
  697. ),
  698. unsqueeze=unsqueeze_feature_tensors,
  699. )
  700. for key in feature_columns
  701. }
  702. else:
  703. features_tensor = convert_pandas_to_torch_tensor(
  704. batch,
  705. columns=feature_columns,
  706. column_dtypes=feature_column_dtypes,
  707. unsqueeze=unsqueeze_feature_tensors,
  708. )
  709. yield (features_tensor, label_tensor)
  710. return TorchIterableDataset(make_generator)
  @PublicAPI
  def to_tf(
      self,
      feature_columns: Union[str, List[str]],
      label_columns: Union[str, List[str]],
      *,
      additional_columns: Union[Optional[str], Optional[List[str]]] = None,
      prefetch_batches: int = 1,
      batch_size: int = 1,
      drop_last: bool = False,
      local_shuffle_buffer_size: Optional[int] = None,
      local_shuffle_seed: Optional[int] = None,
      feature_type_spec: Optional[
          Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]]
      ] = None,
      label_type_spec: Optional[
          Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]]
      ] = None,
      additional_type_spec: Union[
          Optional["tf.TypeSpec"], Optional[Dict[str, "tf.TypeSpec"]]
      ] = None,
  ) -> "tf.data.Dataset":
      """Return a TF Dataset over this dataset.

      .. warning::
          If your dataset contains ragged tensors, this method errors. To prevent
          errors, :ref:`resize your tensors <transforming_tensors>`.

      Examples:
          >>> import ray
          >>> ds = ray.data.read_csv(
          ...     "s3://anonymous@air-example-data/iris.csv"
          ... )
          >>> it = ds.iterator(); it
          DataIterator(Dataset(num_rows=?, schema=Unknown schema))

          If your model accepts a single tensor as input, specify a single feature column.

          >>> it.to_tf(feature_columns="sepal length (cm)", label_columns="target")
          <_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'))>

          If your model accepts a dictionary as input, specify a list of feature columns.

          >>> it.to_tf(["sepal length (cm)", "sepal width (cm)"], "target")
          <_OptionsDataset element_spec=({'sepal length (cm)': TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), 'sepal width (cm)': TensorSpec(shape=(None,), dtype=tf.float64, name='sepal width (cm)')}, TensorSpec(shape=(None,), dtype=tf.int64, name='target'))>

          If your dataset contains multiple features but your model accepts a single
          tensor as input, combine features with
          :class:`~ray.data.preprocessors.Concatenator`.

          >>> from ray.data.preprocessors import Concatenator
          >>> columns_to_concat = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
          >>> preprocessor = Concatenator(columns=columns_to_concat, output_column_name="features")
          >>> it = preprocessor.transform(ds).iterator()
          >>> it.to_tf("features", "target")
          <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name='features'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'))>

          If your model accepts different types, shapes, or names of tensors as input, specify the type spec.
          If type specs are not specified, they are automatically inferred from the schema of the iterator.

          >>> import tensorflow as tf
          >>> it.to_tf(
          ...     feature_columns="features",
          ...     label_columns="target",
          ...     feature_type_spec=tf.TensorSpec(shape=(None, 4), dtype=tf.float32, name="features"),
          ...     label_type_spec=tf.TensorSpec(shape=(None,), dtype=tf.float32, name="label")
          ... )
          <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float32, name='features'), TensorSpec(shape=(None,), dtype=tf.float32, name='label'))>

          If your model accepts additional metadata aside from features and label, specify a single additional column or a list of additional columns.
          A common use case is to include sample weights in the data samples and train a ``tf.keras.Model`` with ``tf.keras.Model.fit``.

          >>> import pandas as pd
          >>> ds = ds.add_column("sample weights", lambda df: pd.Series([1] * len(df)))
          >>> it = ds.iterator()
          >>> it.to_tf(feature_columns="sepal length (cm)", label_columns="target", additional_columns="sample weights")
          <_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.int64, name='sample weights'))>

          If your model accepts different types, shapes, or names for the additional metadata, specify the type spec of the additional column.

          >>> it.to_tf(
          ...     feature_columns="sepal length (cm)",
          ...     label_columns="target",
          ...     additional_columns="sample weights",
          ...     additional_type_spec=tf.TensorSpec(shape=(None,), dtype=tf.float32, name="weight")
          ... )
          <_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.float32, name='weight'))>

      Args:
          feature_columns: Columns that correspond to model inputs. If this is a
              string, the input data is a tensor. If this is a list, the input data
              is a ``dict`` that maps column names to their tensor representation.
          label_columns: Columns that correspond to model targets. If this is a
              string, the target data is a tensor. If this is a list, the target data
              is a ``dict`` that maps column names to their tensor representation.
          additional_columns: Columns that correspond to sample weights or other metadata.
              If this is a string, the weight data is a tensor. If this is a list, the
              weight data is a ``dict`` that maps column names to their tensor representation.
          prefetch_batches: The number of batches to fetch ahead of the current
              batch. If set to greater than 0, a separate threadpool will be used
              to fetch the objects to the local node and format the batches.
              Defaults to 1.
          batch_size: Record batch size. Defaults to 1.
          drop_last: Set to True to drop the last incomplete batch,
              if the dataset size is not divisible by the batch size. If
              False and the size of dataset is not divisible by the batch
              size, then the last batch will be smaller. Defaults to False.
          local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
              using a local in-memory shuffle buffer, and this value will serve as the
              minimum number of rows that must be in the local in-memory shuffle
              buffer in order to yield a batch. When there are no more rows to add to
              the buffer, the remaining rows in the buffer will be drained. This
              buffer size must be greater than or equal to ``batch_size``, and
              therefore ``batch_size`` must also be specified when using local
              shuffling.
          local_shuffle_seed: The seed to use for the local random shuffle.
          feature_type_spec: The `tf.TypeSpec` of `feature_columns`. If there is
              only one column, specify a `tf.TypeSpec`. If there are multiple columns,
              specify a ``dict`` that maps column names to their `tf.TypeSpec`.
              Default is `None` to automatically infer the type of each column.
              A spec you pass is always used as-is, even if `label_type_spec`
              is left to be inferred.
          label_type_spec: The `tf.TypeSpec` of `label_columns`. If there is
              only one column, specify a `tf.TypeSpec`. If there are multiple columns,
              specify a ``dict`` that maps column names to their `tf.TypeSpec`.
              Default is `None` to automatically infer the type of each column.
              A spec you pass is always used as-is, even if `feature_type_spec`
              is left to be inferred.
          additional_type_spec: The `tf.TypeSpec` of `additional_columns`. If there
              is only one column, specify a `tf.TypeSpec`. If there are multiple
              columns, specify a ``dict`` that maps column names to their `tf.TypeSpec`.
              Default is `None` to automatically infer the type of each column.

      Returns:
          A ``tf.data.Dataset`` that yields ``(features, labels)`` tuples, or
          ``(features, labels, additional)`` tuples when ``additional_columns``
          is set.

      Raises:
          ValueError: If tensorflow cannot be imported, or if a type spec has to
              be inferred and one of the requested columns is not in the schema.
      """  # noqa: E501
      from ray.data._internal.utils.tensorflow_utils import (
          convert_ndarray_to_tf_tensor,
          get_type_spec,
      )

      try:
          import tensorflow as tf
      except ImportError as e:
          raise ValueError("tensorflow must be installed!") from e

      def validate_columns(
          columns: Union[str, List[str]], valid_columns: set
      ) -> None:
          """Raise ``ValueError`` if any of ``columns`` is not in ``valid_columns``."""
          names = columns if isinstance(columns, list) else [columns]
          for column in names:
              if column not in valid_columns:
                  raise ValueError(
                      f"You specified '{column}' in `feature_columns`, "
                      f"`label_columns`, or `additional_columns`, but there's no "
                      f"column named '{column}' in the dataset. "
                      f"Valid column names are: {valid_columns}."
                  )

      def convert_batch_to_tensors(
          batch: Dict[str, np.ndarray],
          *,
          columns: Union[str, List[str]],
          type_spec: Union[tf.TypeSpec, Dict[str, tf.TypeSpec]],
      ) -> Union[tf.Tensor, Dict[str, tf.Tensor]]:
          """Convert the selected batch column(s) to a tensor or a dict of tensors."""
          if isinstance(columns, str):
              return convert_ndarray_to_tf_tensor(batch[columns], type_spec=type_spec)
          return {
              column: convert_ndarray_to_tf_tensor(
                  batch[column], type_spec=type_spec[column]
              )
              for column in columns
          }

      infer_features = feature_type_spec is None
      infer_labels = label_type_spec is None
      infer_additional = (
          additional_columns is not None and additional_type_spec is None
      )
      if infer_features or infer_labels or infer_additional:
          # Fetch the schema at most once, and only when some spec is missing.
          schema = self.schema()
          valid_columns = set(schema.names)
          if infer_features or infer_labels:
              validate_columns(feature_columns, valid_columns)
              validate_columns(label_columns, valid_columns)
          # Only fill in the specs the caller left as None; an explicitly passed
          # spec must never be replaced by an inferred one.
          if infer_features:
              feature_type_spec = get_type_spec(schema, columns=feature_columns)
          if infer_labels:
              label_type_spec = get_type_spec(schema, columns=label_columns)
          if infer_additional:
              validate_columns(additional_columns, valid_columns)
              additional_type_spec = get_type_spec(
                  schema, columns=additional_columns
              )

      def generator():
          for batch in self._iter_batches(
              prefetch_batches=prefetch_batches,
              batch_size=batch_size,
              drop_last=drop_last,
              local_shuffle_buffer_size=local_shuffle_buffer_size,
              local_shuffle_seed=local_shuffle_seed,
          ):
              assert isinstance(batch, dict)
              features = convert_batch_to_tensors(
                  batch, columns=feature_columns, type_spec=feature_type_spec
              )
              labels = convert_batch_to_tensors(
                  batch, columns=label_columns, type_spec=label_type_spec
              )
              if additional_columns is None:
                  yield features, labels
              else:
                  additional_metadata = convert_batch_to_tensors(
                      batch,
                      columns=additional_columns,
                      type_spec=additional_type_spec,
                  )
                  yield features, labels, additional_metadata

      # The signature must have the same arity as the tuples yielded above.
      output_signature = (feature_type_spec, label_type_spec)
      if additional_columns is not None:
          output_signature += (additional_type_spec,)
      dataset = tf.data.Dataset.from_generator(
          generator, output_signature=output_signature
      )

      # NOTE(review): tf.data auto-sharding is disabled, presumably because
      # distributing data across workers is handled on the Ray side — confirm.
      options = tf.data.Options()
      options.experimental_distribute.auto_shard_policy = (
          tf.data.experimental.AutoShardPolicy.OFF
      )
      return dataset.with_options(options)
  @PublicAPI
  def materialize(self) -> "MaterializedDataset":
      """Execute this iterator and materialize its data into object store memory.

      .. note::
          Calling this triggers execution: every block the iterator would
          produce is materialized, and the contents are returned as a
          :class:`~ray.data.dataset.MaterializedDataset` for further processing.

      Returns:
          A :class:`~ray.data.dataset.MaterializedDataset` whose logical plan
          reads from the ref bundles produced by this iterator.
      """
      from ray.data.dataset import MaterializedDataset

      bundle_iterator, iterator_stats, _, _ = self._to_ref_bundle_iterator()
      # Drain the iterator completely so the plan's input holds every bundle.
      materialized_bundles = list(bundle_iterator)

      plan = ExecutionPlan(iterator_stats, self.get_context())
      source_op = InputData(input_data=materialized_bundles)
      return MaterializedDataset(plan, LogicalPlan(source_op, plan._context))
  # Alias kept for backwards compatibility: code that still refers to
  # ``DatasetIterator`` resolves to the ``DataIterator`` class.
  DatasetIterator = DataIterator