| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
#!/usr/bin/env python

import argparse

import ray
from ray import tune
from ray.tune.examples.pbt_function import pbt_function
from ray.tune.schedulers.pb2 import PB2
def main() -> None:
    """Tune ``lr`` of ``pbt_function`` with the PB2 scheduler and report the best config.

    Command-line flags:
        --smoke-test: start Ray with only 2 CPUs to force trial pausing
            during the test.

    Runs 8 trials, each stopped after 30 training iterations. PB2 maximizes
    ``mean_accuracy`` and perturbs ``lr`` every ``perturbation_interval``
    iterations. Prints the config of the best trial when tuning finishes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    # parse_known_args ignores unrecognized flags instead of exiting with an error.
    args, _ = parser.parse_known_args()
    if args.smoke_test:
        ray.init(num_cpus=2)  # force pausing to happen for test

    # Shared by the scheduler and the trainable so that checkpoints are
    # written at the same cadence as perturbations.
    perturbation_interval = 5
    pb2 = PB2(
        time_attr="training_iteration",
        perturbation_interval=perturbation_interval,
        # PB2 only explores the hyperparameters listed here, each within
        # its [min, max] bounds.
        hyperparam_bounds={
            "lr": [0.0001, 0.02],
        },
    )

    tuner = tune.Tuner(
        pbt_function,
        run_config=tune.RunConfig(
            name="pbt_test",
            verbose=False,
            stop={
                "training_iteration": 30,
            },
            # Abort the whole experiment as soon as any trial errors.
            failure_config=tune.FailureConfig(
                fail_fast=True,
            ),
        ),
        tune_config=tune.TuneConfig(
            scheduler=pb2,
            metric="mean_accuracy",
            mode="max",
            num_samples=8,
            reuse_actors=True,
        ),
        param_space={
            # Starting value; PB2 perturbs it within hyperparam_bounds above.
            "lr": 0.0001,
            # Not listed in hyperparam_bounds, so PB2 never perturbs it. It
            # is carried through unchanged and has no effect on the model
            # training in this example.
            "some_other_factor": 1,
            # Also not perturbed. It sets the checkpoint frequency, so
            # checkpoints and perturbations happen at the same frequency.
            "checkpoint_interval": perturbation_interval,
        },
    )

    results = tuner.fit()
    print("Best hyperparameters found were: ", results.get_best_result().config)


if __name__ == "__main__":
    main()
|