ax_example.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. """This example demonstrates the usage of AxSearch with Ray Tune.
  2. It also checks that it is usable with a separate scheduler.
  3. Requires the Ax library to be installed (`pip install ax-platform`).
  4. """
  5. import time
  6. import numpy as np
  7. from ray import tune
  8. from ray.tune.schedulers import AsyncHyperBandScheduler
  9. from ray.tune.search.ax import AxSearch
  10. def hartmann6(x):
  11. alpha = np.array([1.0, 1.2, 3.0, 3.2])
  12. A = np.array(
  13. [
  14. [10, 3, 17, 3.5, 1.7, 8],
  15. [0.05, 10, 17, 0.1, 8, 14],
  16. [3, 3.5, 1.7, 10, 17, 8],
  17. [17, 8, 0.05, 10, 0.1, 14],
  18. ]
  19. )
  20. P = 10 ** (-4) * np.array(
  21. [
  22. [1312, 1696, 5569, 124, 8283, 5886],
  23. [2329, 4135, 8307, 3736, 1004, 9991],
  24. [2348, 1451, 3522, 2883, 3047, 6650],
  25. [4047, 8828, 8732, 5743, 1091, 381],
  26. ]
  27. )
  28. y = 0.0
  29. for j, alpha_j in enumerate(alpha):
  30. t = 0
  31. for k in range(6):
  32. t += A[j, k] * ((x[k] - P[j, k]) ** 2)
  33. y -= alpha_j * np.exp(-t)
  34. return y
  35. def easy_objective(config):
  36. for i in range(config["iterations"]):
  37. x = np.array([config.get("x{}".format(i + 1)) for i in range(6)])
  38. tune.report(
  39. {
  40. "timesteps_total": i,
  41. "hartmann6": hartmann6(x),
  42. "l2norm": np.sqrt((x**2).sum()),
  43. }
  44. )
  45. time.sleep(0.02)
  46. if __name__ == "__main__":
  47. import argparse
  48. parser = argparse.ArgumentParser()
  49. parser.add_argument(
  50. "--smoke-test", action="store_true", help="Finish quickly for testing"
  51. )
  52. args, _ = parser.parse_known_args()
  53. algo = AxSearch(
  54. parameter_constraints=["x1 + x2 <= 2.0"], # Optional.
  55. outcome_constraints=["l2norm <= 1.25"], # Optional.
  56. )
  57. # Limit to 4 concurrent trials
  58. algo = tune.search.ConcurrencyLimiter(algo, max_concurrent=4)
  59. scheduler = AsyncHyperBandScheduler()
  60. tuner = tune.Tuner(
  61. easy_objective,
  62. run_config=tune.RunConfig(
  63. name="ax",
  64. stop={"timesteps_total": 100},
  65. ),
  66. tune_config=tune.TuneConfig(
  67. metric="hartmann6", # provided in the 'easy_objective' function
  68. mode="min",
  69. search_alg=algo,
  70. scheduler=scheduler,
  71. num_samples=10 if args.smoke_test else 50,
  72. ),
  73. param_space={
  74. "iterations": 100,
  75. "x1": tune.uniform(0.0, 1.0),
  76. "x2": tune.uniform(0.0, 1.0),
  77. "x3": tune.uniform(0.0, 1.0),
  78. "x4": tune.uniform(0.0, 1.0),
  79. "x5": tune.uniform(0.0, 1.0),
  80. "x6": tune.uniform(0.0, 1.0),
  81. },
  82. )
  83. results = tuner.fit()
  84. print("Best hyperparameters found were: ", results.get_best_result().config)