import json
import os
import pickle
import tempfile
import time
from collections import Counter

import numpy as np

from ray import tune
from ray._private.test_utils import safe_write_to_results_json
from ray.tune import Checkpoint
from ray.tune.callback import Callback


class ProgressCallback(Callback):
    """Tune callback that periodically records run progress to a JSON file.

    At most once every ``update_interval`` seconds, the current iteration and
    a count of trials per status are passed to
    ``safe_write_to_results_json`` with the path
    ``/tmp/release_test_out.json``.
    """

    def __init__(self):
        # Wall-clock timestamp of the last write; 0 makes the first
        # ``on_step_end`` call write immediately.
        self.last_update = 0
        # Minimum number of seconds between two writes.
        self.update_interval = 60

    def on_step_end(self, iteration, trials, **kwargs):
        """Write a progress snapshot if the update interval has elapsed.

        Args:
            iteration: Iteration counter passed in by the caller.
            trials: Iterable of trial objects exposing a ``status`` attribute.
            **kwargs: Ignored.
        """
        now = time.time()
        if now - self.last_update <= self.update_interval:
            return

        result = {
            "last_update": now,
            "iteration": iteration,
            "trial_states": dict(Counter(trial.status for trial in trials)),
        }
        safe_write_to_results_json(result, "/tmp/release_test_out.json")
        self.last_update = now


class TestDurableTrainable(tune.Trainable):
    """Class-based trainable that sleeps, reports a score and writes bogus
    checkpoints of a configurable size.
    """

    def __init__(self, *args, **kwargs):
        # ``setup_env`` runs before the base-class constructor so subclasses
        # can prepare their environment first.
        self.setup_env()
        super().__init__(*args, **kwargs)

    def setup_env(self):
        """Hook for subclasses; called before ``tune.Trainable.__init__``."""
        pass

    def setup(self, config):
        """Read the benchmark parameters from ``config``.

        Expected keys: ``num_iters``, ``sleep_time``, ``score``,
        ``checkpoint_iters`` and ``checkpoint_size_b``.
        """
        self._num_iters = int(config["num_iters"])
        self._sleep_time = config["sleep_time"]
        self._score = config["score"]

        self._checkpoint_iters = config["checkpoint_iters"]
        self._checkpoint_size_b = config["checkpoint_size_b"]
        # Each checkpoint item is an np.float64, i.e. 8 bytes.
        self._checkpoint_num_items = self._checkpoint_size_b // 8

        self._iter = 0

    def step(self):
        """Sleep (except on the first step) and return the current score.

        Returns:
            A dict with ``score`` set to the step index plus the configured
            base score, and ``done=True`` once the step index has reached
            ``num_iters``.
        """
        if self._iter > 0:
            time.sleep(self._sleep_time)

        res = dict(score=self._iter + self._score)

        if self._iter >= self._num_iters:
            res["done"] = True

        self._iter += 1
        return res

    def save_checkpoint(self, tmp_checkpoint_dir):
        """Pickle a random float array into ``bogus.ckpt`` in the given dir."""
        checkpoint_file = os.path.join(tmp_checkpoint_dir, "bogus.ckpt")
        checkpoint_data = np.random.uniform(0, 1, size=self._checkpoint_num_items)
        with open(checkpoint_file, "wb") as fp:
            pickle.dump(checkpoint_data, fp)

    def load_checkpoint(self, checkpoint):
        """No-op: the checkpoint payload is random data and is never read."""
        pass


def function_trainable(config):
    """Function trainable that reports a score and optional bogus checkpoints.

    Runs ``num_iters`` iterations. Each iteration reports
    ``{"score": i + score}`` and then sleeps for ``sleep_time`` seconds.
    On every ``checkpoint_iters``-th iteration (starting with iteration 0) the
    report is accompanied by a checkpoint directory containing
    ``checkpoint_num_files`` pickled random arrays of roughly
    ``checkpoint_size_b`` bytes each.

    Checkpointing is disabled when ``checkpoint_iters <= 0`` or
    ``checkpoint_size_b <= 0``.

    Args:
        config: Dict with the keys ``num_iters``, ``sleep_time``, ``score``,
            ``checkpoint_iters``, ``checkpoint_size_b`` and
            ``checkpoint_num_files``.
    """
    num_iters = int(config["num_iters"])
    sleep_time = config["sleep_time"]
    score = config["score"]

    checkpoint_iters = config["checkpoint_iters"]
    checkpoint_size_b = config["checkpoint_size_b"]
    # Each checkpoint item is an np.float64, i.e. 8 bytes.
    checkpoint_num_items = checkpoint_size_b // 8
    checkpoint_num_files = config["checkpoint_num_files"]

    # ``checkpoint_iters`` must be strictly positive: a value of 0 (which the
    # caller produces when the checkpoint period is shorter than one
    # iteration) would otherwise raise ZeroDivisionError in the modulo below.
    checkpointing_enabled = checkpoint_iters > 0 and checkpoint_size_b > 0

    for i in range(num_iters):
        metrics = {"score": i + score}
        if checkpointing_enabled and i % checkpoint_iters == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                # Use a dedicated index so the outer loop variable is not
                # shadowed.
                for file_idx in range(checkpoint_num_files):
                    checkpoint_file = os.path.join(tmpdir, f"bogus_{file_idx}.ckpt")
                    checkpoint_data = np.random.uniform(
                        0, 1, size=checkpoint_num_items
                    )
                    with open(checkpoint_file, "wb") as fp:
                        pickle.dump(checkpoint_data, fp)
                # Report while the temporary directory still exists.
                tune.report(metrics, checkpoint=Checkpoint.from_directory(tmpdir))
        else:
            tune.report(metrics)

        time.sleep(sleep_time)


def timed_tune_run(
    name: str,
    num_samples: int,
    results_per_second: int = 1,
    trial_length_s: int = 1,
    max_runtime: int = 300,
    checkpoint_freq_s: int = -1,
    checkpoint_size_b: int = 0,
    checkpoint_num_files: int = 1,
    **tune_kwargs,
) -> bool:
    """Run a Tune benchmark and check it finishes within a time budget.

    The elapsed time and final trial states are written as JSON to the path
    in the ``TEST_OUTPUT_JSON`` environment variable (default
    ``/tmp/tune_test.json``), and a PASSED/FAILED summary is printed.

    Args:
        name: Benchmark name used in the printed summary.
        num_samples: Number of trials, forwarded to ``tune.run``.
        results_per_second: Results each trial reports per second; the
            per-iteration sleep time is its reciprocal.
        trial_length_s: Target trial duration in seconds.
        max_runtime: Time budget in seconds for the whole run.
        checkpoint_freq_s: Seconds between checkpoints; negative disables
            checkpointing.
        checkpoint_size_b: Approximate size in bytes of each checkpoint file.
        checkpoint_num_files: Files per checkpoint (function trainable only).
        **tune_kwargs: Extra keyword arguments for ``tune.run``. If
            ``storage_path`` starts with ``s3://`` or ``gs://``, the
            class-based ``TestDurableTrainable`` is used and
            ``checkpoint_freq`` is set from the checkpoint interval.

    Returns:
        True if the run took at most ``max_runtime`` seconds, else False.
    """
    storage_path = tune_kwargs.get("storage_path")
    durable = bool(storage_path) and storage_path.startswith(("s3://", "gs://"))

    sleep_time = 1.0 / results_per_second
    num_iters = int(trial_length_s / sleep_time)
    checkpoint_iters = -1
    if checkpoint_freq_s >= 0:
        checkpoint_iters = int(checkpoint_freq_s / sleep_time)

    config = {
        "score": tune.uniform(0.0, 1.0),
        "num_iters": num_iters,
        "sleep_time": sleep_time,
        "checkpoint_iters": checkpoint_iters,
        "checkpoint_size_b": checkpoint_size_b,
        "checkpoint_num_files": checkpoint_num_files,
    }

    print(f"Starting benchmark with config: {config}")

    # Caller-supplied kwargs take precedence over the defaults.
    run_kwargs = {"reuse_actors": True, "verbose": 2}
    run_kwargs.update(tune_kwargs)

    _train = function_trainable

    if durable:
        _train = TestDurableTrainable
        run_kwargs["checkpoint_freq"] = checkpoint_iters

    start_time = time.monotonic()
    analysis = tune.run(
        _train,
        config=config,
        num_samples=num_samples,
        raise_on_failed_trial=False,
        **run_kwargs,
    )
    time_taken = time.monotonic() - start_time

    result = {
        "time_taken": time_taken,
        "trial_states": dict(Counter(trial.status for trial in analysis.trials)),
        "last_update": time.time(),
    }

    test_output_json = os.environ.get("TEST_OUTPUT_JSON", "/tmp/tune_test.json")
    with open(test_output_json, "wt") as f:
        json.dump(result, f)

    success = time_taken <= max_runtime

    if not success:
        print(
            f"The {name} test took {time_taken:.2f} seconds, but should not "
            f"have exceeded {max_runtime:.2f} seconds. Test failed. \n\n"
            f"--- FAILED: {name.upper()} ::: "
            f"{time_taken:.2f} > {max_runtime:.2f} ---"
        )
    else:
        print(
            f"The {name} test took {time_taken:.2f} seconds, which "
            f"is below the budget of {max_runtime:.2f} seconds. "
            f"Test successful. \n\n"
            f"--- PASSED: {name.upper()} ::: "
            f"{time_taken:.2f} <= {max_runtime:.2f} ---"
        )

    return success