| 1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
#!/usr/bin/env python
"""Hyperband early-stopping example for Ray Tune.

Samples ``width`` / ``height`` hyperparameters for ``MyTrainableClass`` and
uses ``HyperBandScheduler`` to terminate unpromising trials early while
maximizing ``episode_reward_mean``. Pass ``--smoke-test`` for a short run.
"""
import argparse

import ray
from ray import tune
from ray.tune.schedulers import HyperBandScheduler
from ray.tune.utils.mock_trainable import MyTrainableClass


def _parse_args():
    """Parse the command line; unrecognized arguments are accepted and dropped."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    known, _unknown = cli.parse_known_args()
    return known


def main():
    """Run the Hyperband search and print the best trial's configuration."""
    smoke_test = _parse_args().smoke_test

    # Smoke tests pin the local Ray instance to 4 CPUs; otherwise num_cpus is
    # left as None so Ray uses its own default.
    ray.init(num_cpus=4 if smoke_test else None)

    # `training_iteration` is filled in automatically by Tune, so it is used as
    # the scheduler's time unit. No metric/mode is passed here: the objective
    # (`episode_reward_mean`, maximized) is declared on the TuneConfig below.
    scheduler = HyperBandScheduler(time_attr="training_iteration", max_t=200)

    # Smoke-test vs. full-run sizing: (iterations per trial, number of trials).
    if smoke_test:
        max_iterations, trial_count = 1, 20
    else:
        max_iterations, trial_count = 200, 200

    run_config = tune.RunConfig(
        name="hyperband_test",
        stop={"training_iteration": max_iterations},
        verbose=1,
        failure_config=tune.FailureConfig(fail_fast=True),
    )
    tune_config = tune.TuneConfig(
        num_samples=trial_count,
        metric="episode_reward_mean",
        mode="max",
        scheduler=scheduler,
    )
    search_space = {
        "width": tune.randint(10, 90),
        "height": tune.randint(0, 100),
    }

    tuner = tune.Tuner(
        MyTrainableClass,
        run_config=run_config,
        tune_config=tune_config,
        param_space=search_space,
    )
    best = tuner.fit().get_best_result()
    print("Best hyperparameters found were: ", best.config)


if __name__ == "__main__":
    main()
|