optuna_multiobjective_example.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. """This example demonstrates the usage of Optuna with Ray Tune for
  2. multi-objective optimization.
  3. Please note that schedulers may not work correctly with multi-objective
  4. optimization.
  5. Requires the Optuna library to be installed (`pip install optuna`).
  6. """
  7. import time
  8. import ray
  9. from ray import tune
  10. from ray.tune.search import ConcurrencyLimiter
  11. from ray.tune.search.optuna import OptunaSearch
  12. def evaluation_fn(step, width, height):
  13. return (0.1 + width * step / 100) ** (-1) + height * 0.1
  14. def easy_objective(config):
  15. # Hyperparameters
  16. width, height = config["width"], config["height"]
  17. for step in range(config["steps"]):
  18. # Iterative training function - can be any arbitrary training procedure
  19. intermediate_score = evaluation_fn(step, width, height)
  20. # Feed the score back back to Tune.
  21. tune.report(
  22. {
  23. "iterations": step,
  24. "loss": intermediate_score,
  25. "gain": intermediate_score * width,
  26. }
  27. )
  28. time.sleep(0.1)
  29. def run_optuna_tune(smoke_test=False):
  30. algo = OptunaSearch(metric=["loss", "gain"], mode=["min", "max"])
  31. algo = ConcurrencyLimiter(algo, max_concurrent=4)
  32. tuner = tune.Tuner(
  33. easy_objective,
  34. tune_config=tune.TuneConfig(
  35. search_alg=algo,
  36. num_samples=10 if smoke_test else 100,
  37. ),
  38. param_space={
  39. "steps": 100,
  40. "width": tune.uniform(0, 20),
  41. "height": tune.uniform(-100, 100),
  42. # This is an ignored parameter.
  43. "activation": tune.choice(["relu", "tanh"]),
  44. },
  45. )
  46. results = tuner.fit()
  47. print(
  48. "Best hyperparameters for loss found were: ",
  49. results.get_best_result("loss", "min").config,
  50. )
  51. print(
  52. "Best hyperparameters for gain found were: ",
  53. results.get_best_result("gain", "max").config,
  54. )
  55. if __name__ == "__main__":
  56. import argparse
  57. parser = argparse.ArgumentParser()
  58. parser.add_argument(
  59. "--smoke-test", action="store_true", help="Finish quickly for testing"
  60. )
  61. args, _ = parser.parse_known_args()
  62. ray.init(configure_logging=False)
  63. run_optuna_tune(smoke_test=args.smoke_test)