#!/usr/bin/env python
"""Tune a toy objective with Ray Tune's HyperBand early-stopping scheduler."""

import argparse
import json
import os
import tempfile

import numpy as np

import ray
from ray import tune
from ray.tune import Checkpoint
from ray.tune.schedulers import HyperBandScheduler

# File written into (and read back from) every checkpoint directory.
_CHECKPOINT_FILE = "checkpoint.json"
# Each trial runs timesteps [start, _NUM_TIMESTEPS) and then returns.
_NUM_TIMESTEPS = 100
# A checkpoint is saved on every timestep divisible by this value.
_CHECKPOINT_EVERY = 3


def _get_start_step() -> int:
    """Return the first timestep this trial should run.

    If Tune hands the trial a checkpoint (``tune.get_checkpoint()``), resume
    right after the timestep stored in it; otherwise start from 0.
    """
    checkpoint = tune.get_checkpoint()
    if not checkpoint:
        return 0
    with checkpoint.as_directory() as checkpoint_dir:
        path = os.path.join(checkpoint_dir, _CHECKPOINT_FILE)
        with open(path, encoding="utf-8") as f:
            return json.load(f)["timestep"] + 1


def train_func(config: dict) -> None:
    """Toy trainable that reports a tanh-shaped reward once per timestep.

    Args:
        config: Trial hyperparameters. Reads ``"width"`` (divisor of the
            timestep inside ``tanh``, default 1) and ``"height"`` (scale of
            the result, default 1).
    """
    width = config.get("width", 1)
    height = config.get("height", 1)

    for timestep in range(_get_start_step(), _NUM_TIMESTEPS):
        v = np.tanh(float(timestep) / width) * height
        # Here we use `episode_reward_mean`, but you can also report other
        # objectives such as loss or accuracy.
        metrics = {"episode_reward_mean": v}

        # Checkpoint the state of the training every few steps.
        # Note that this is only required for certain schedulers.
        if timestep % _CHECKPOINT_EVERY != 0:
            tune.report(metrics)
            continue

        # The temporary directory must stay alive until tune.report has
        # consumed the checkpoint, so the report happens inside the `with`.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            path = os.path.join(temp_checkpoint_dir, _CHECKPOINT_FILE)
            with open(path, "w", encoding="utf-8") as f:
                json.dump({"timestep": timestep}, f)
            tune.report(
                metrics,
                checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
            )


def main() -> None:
    """Parse CLI flags, run the HyperBand tuning job and print the best config."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    args, _ = parser.parse_known_args()
    ray.init(num_cpus=4 if args.smoke_test else None)

    # Hyperband early stopping, configured with `episode_reward_mean` as the
    # objective and `training_iteration` as the time unit,
    # which is automatically filled by Tune.
    # NOTE(review): train_func stops on its own after _NUM_TIMESTEPS (100)
    # iterations, below max_t=200 -- confirm this bracket size is intended.
    hyperband = HyperBandScheduler(max_t=200)

    tuner = tune.Tuner(
        train_func,
        run_config=tune.RunConfig(
            name="hyperband_test",
            # 99999 is effectively "no limit"; trials end at _NUM_TIMESTEPS.
            stop={"training_iteration": 10 if args.smoke_test else 99999},
            failure_config=tune.FailureConfig(
                fail_fast=True,
            ),
        ),
        tune_config=tune.TuneConfig(
            num_samples=20,
            metric="episode_reward_mean",
            mode="max",
            scheduler=hyperband,
        ),
        param_space={"height": tune.uniform(0, 100)},
    )
    results = tuner.fit()

    print("Best hyperparameters found were: ", results.get_best_result().config)


if __name__ == "__main__":
    main()