| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- #!/usr/bin/env python
- import argparse
- import json
- import os
- import tempfile
- import numpy as np
- import ray
- from ray import tune
- from ray.tune import Checkpoint
- from ray.tune.schedulers import HyperBandScheduler
def train_func(config):
    """Toy Tune trainable that reports a tanh-shaped reward curve.

    At step ``t`` the reported reward is ``height * tanh(t / width)``, with
    ``height`` and ``width`` read from ``config`` (each defaults to 1). If Tune
    supplies a checkpoint, training resumes at the step after the one that
    was checkpointed. A checkpoint is written every 3 steps so that schedulers
    which pause and later resume trials can restore progress.

    Args:
        config: Trial hyperparameters supplied by Tune (a mapping; only the
            optional ``"width"`` and ``"height"`` keys are read).
    """
    start = 0
    restored = tune.get_checkpoint()
    if restored:
        with restored.as_directory() as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint.json")
            with open(path, encoding="utf-8") as f:
                # The checkpoint stores the last completed step; resume after it.
                start = json.load(f)["timestep"] + 1

    for timestep in range(start, 100):
        v = np.tanh(float(timestep) / config.get("width", 1))
        v *= config.get("height", 1)
        # Here we use `episode_reward_mean`, but you can also report other
        # objectives such as loss or accuracy.
        metrics = {"episode_reward_mean": v}

        if timestep % 3 != 0:
            # Non-checkpoint step: report metrics only, no temp dir needed.
            tune.report(metrics, checkpoint=None)
            continue

        # Checkpoint the state of the training every 3 steps.
        # Note that this is only required for certain schedulers.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            with open(
                os.path.join(temp_checkpoint_dir, "checkpoint.json"),
                "w",
                encoding="utf-8",
            ) as f:
                json.dump({"timestep": timestep}, f)
            # Report inside the `with` so the directory still exists while
            # Tune persists the checkpoint contents.
            tune.report(
                metrics,
                checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
            )
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    cli_args, _unknown = cli.parse_known_args()
    smoke = cli_args.smoke_test

    # Smoke tests cap Ray at 4 CPUs; otherwise Ray picks the CPU count itself.
    ray.init(num_cpus=4 if smoke else None)

    # HyperBand early stopping. The objective (`episode_reward_mean`,
    # maximized) comes from the TuneConfig below; the time unit is
    # `training_iteration`, which Tune fills in automatically on each report.
    scheduler = HyperBandScheduler(time_attr="training_iteration", max_t=200)

    run_config = tune.RunConfig(
        name="hyperband_test",
        stop={"training_iteration": 10 if smoke else 99999},
        # Abort the whole experiment on the first trial error.
        failure_config=tune.FailureConfig(fail_fast=True),
    )
    tune_config = tune.TuneConfig(
        num_samples=20,
        metric="episode_reward_mean",
        mode="max",
        scheduler=scheduler,
    )
    tuner = tune.Tuner(
        train_func,
        run_config=run_config,
        tune_config=tune_config,
        param_space={"height": tune.uniform(0, 100)},
    )

    best = tuner.fit().get_best_result()
    print("Best hyperparameters found were: ", best.config)
|