test_utils.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. from sklearn.datasets import load_breast_cancer
  2. from ray import tune
  3. from ray.data import Dataset, Datasource, ReadTask, read_datasource
  4. from ray.data.block import BlockMetadata
  5. from ray.tune.impl.utils import execute_dataset
  6. # TODO(xwjiang): Enable this once Clark's out-of-band serialization has landed.
  7. class TestDatasource(Datasource):
  8. def prepare_read(self, parallelism: int, **read_args):
  9. import pyarrow as pa
  10. def load_data():
  11. data_raw = load_breast_cancer(as_frame=True)
  12. dataset_df = data_raw["data"]
  13. dataset_df["target"] = data_raw["target"]
  14. return [pa.Table.from_pandas(dataset_df)]
  15. meta = BlockMetadata(
  16. num_rows=None,
  17. size_bytes=None,
  18. input_files=None,
  19. exec_stats=None,
  20. )
  21. return [ReadTask(load_data, meta)]
  22. def gen_dataset_func() -> Dataset:
  23. test_datasource = TestDatasource()
  24. return read_datasource(test_datasource)
  25. def test_grid_search():
  26. ds1 = gen_dataset_func().lazy().map(lambda x: x)
  27. ds2 = gen_dataset_func().lazy().map(lambda x: x)
  28. assert not ds1._plan._has_final_stage_snapshot()
  29. assert not ds2._plan._has_final_stage_snapshot()
  30. param_space = {"train_dataset": tune.grid_search([ds1, ds2])}
  31. execute_dataset(param_space)
  32. executed_ds = param_space["train_dataset"]["grid_search"]
  33. assert len(executed_ds) == 2
  34. assert executed_ds[0]._plan._has_final_stage_snapshot()
  35. assert executed_ds[1]._plan._has_final_stage_snapshot()
  36. def test_choice():
  37. ds1 = gen_dataset_func().lazy().map(lambda x: x)
  38. ds2 = gen_dataset_func().lazy().map(lambda x: x)
  39. assert not ds1._plan._has_final_stage_snapshot()
  40. assert not ds2._plan._has_final_stage_snapshot()
  41. param_space = {"train_dataset": tune.choice([ds1, ds2])}
  42. execute_dataset(param_space)
  43. executed_ds = param_space["train_dataset"].categories
  44. assert len(executed_ds) == 2
  45. assert executed_ds[0]._plan._has_final_stage_snapshot()
  46. assert executed_ds[1]._plan._has_final_stage_snapshot()
  47. if __name__ == "__main__":
  48. import sys
  49. import pytest
  50. sys.exit(pytest.main(["-v", "-x", __file__]))