| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- from sklearn.datasets import load_breast_cancer
- from ray import tune
- from ray.data import Dataset, Datasource, ReadTask, read_datasource
- from ray.data.block import BlockMetadata
- from ray.tune.impl.utils import execute_dataset
- # TODO(xwjiang): Enable this when Clark's out-of-band-serialization is landed.
class TestDatasource(Datasource):
    """Datasource that serves the scikit-learn breast-cancer table as one block.

    The tests in this file use it to build datasets through
    ``read_datasource``.
    """

    # The ``Test`` prefix matches pytest's default class-collection pattern;
    # opt out explicitly because this is a helper, not a test case.
    __test__ = False

    def prepare_read(self, parallelism: int, **read_args):
        """Return a single read task that loads the whole dataset.

        Args:
            parallelism: Not used by this implementation; exactly one
                ``ReadTask`` is returned regardless of its value.
            **read_args: Not used by this implementation.

        Returns:
            A one-element list holding the ``ReadTask``.
        """
        import pyarrow as pa

        def load_data():
            """Load features plus the target column as a one-table list."""
            data_raw = load_breast_cancer(as_frame=True)
            dataset_df = data_raw["data"]
            dataset_df["target"] = data_raw["target"]
            return [pa.Table.from_pandas(dataset_df)]

        # Every metadata field is left as None. ``schema`` is passed
        # explicitly because BlockMetadata is understood to declare it as a
        # field without a default in the Ray releases that provide
        # ``Dataset.lazy()``; omitting it would raise TypeError there.
        # NOTE(review): confirm against the pinned Ray version.
        meta = BlockMetadata(
            num_rows=None,
            size_bytes=None,
            schema=None,
            input_files=None,
            exec_stats=None,
        )
        return [ReadTask(load_data, meta)]
def gen_dataset_func() -> Dataset:
    """Build a dataset by reading from a fresh ``TestDatasource``."""
    return read_datasource(TestDatasource())
def test_grid_search():
    """Datasets under ``tune.grid_search`` get a final-stage snapshot.

    Two lazy datasets start without a final-stage snapshot; after
    ``execute_dataset`` runs on the param space, both entries found under
    the ``"grid_search"`` key must report one.
    """
    pending = [gen_dataset_func().lazy().map(lambda x: x) for _ in range(2)]
    for dataset in pending:
        # Precondition: the lazy pipelines have not been executed yet.
        assert not dataset._plan._has_final_stage_snapshot()

    search_space = {"train_dataset": tune.grid_search(list(pending))}
    execute_dataset(search_space)

    materialized = search_space["train_dataset"]["grid_search"]
    assert len(materialized) == 2
    for index in range(2):
        assert materialized[index]._plan._has_final_stage_snapshot()
def test_choice():
    """Datasets under ``tune.choice`` get a final-stage snapshot.

    Two lazy datasets start without a final-stage snapshot; after
    ``execute_dataset`` runs on the param space, both entries in the
    sampler's ``categories`` must report one.
    """
    pending = [gen_dataset_func().lazy().map(lambda x: x) for _ in range(2)]
    for dataset in pending:
        # Precondition: the lazy pipelines have not been executed yet.
        assert not dataset._plan._has_final_stage_snapshot()

    search_space = {"train_dataset": tune.choice(list(pending))}
    execute_dataset(search_space)

    materialized = search_space["train_dataset"].categories
    assert len(materialized) == 2
    for index in range(2):
        assert materialized[index]._plan._has_final_stage_snapshot()
if __name__ == "__main__":
    # Allow running this file directly: run pytest on it verbosely, stop at
    # the first failure, and use pytest's result as the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", "-x", __file__])
    sys.exit(exit_code)
|