logging_example.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. #!/usr/bin/env python
  2. import argparse
  3. import time
  4. from ray import tune
  5. from ray.tune.logger import LoggerCallback
  6. class TestLoggerCallback(LoggerCallback):
  7. def on_trial_result(self, iteration, trials, trial, result, **info):
  8. print(f"TestLogger for trial {trial}: {result}")
  9. def trial_str_creator(trial):
  10. return "{}_{}_123".format(trial.trainable_name, trial.trial_id)
  11. def evaluation_fn(step, width, height):
  12. time.sleep(0.1)
  13. return (0.1 + width * step / 100) ** (-1) + height * 0.1
  14. def easy_objective(config):
  15. # Hyperparameters
  16. width, height = config["width"], config["height"]
  17. for step in range(config["steps"]):
  18. # Iterative training function - can be any arbitrary training procedure
  19. intermediate_score = evaluation_fn(step, width, height)
  20. # Feed the score back back to Tune.
  21. tune.report({"iterations": step, "mean_loss": intermediate_score})
  22. if __name__ == "__main__":
  23. parser = argparse.ArgumentParser()
  24. parser.add_argument(
  25. "--smoke-test", action="store_true", help="Finish quickly for testing"
  26. )
  27. args, _ = parser.parse_known_args()
  28. tuner = tune.Tuner(
  29. easy_objective,
  30. run_config=tune.RunConfig(
  31. name="hyperband_test",
  32. callbacks=[TestLoggerCallback()],
  33. stop={"training_iteration": 1 if args.smoke_test else 100},
  34. ),
  35. tune_config=tune.TuneConfig(
  36. metric="mean_loss",
  37. mode="min",
  38. num_samples=5,
  39. trial_name_creator=trial_str_creator,
  40. trial_dirname_creator=trial_str_creator,
  41. ),
  42. param_space={
  43. "steps": 100,
  44. "width": tune.randint(10, 100),
  45. "height": tune.loguniform(10, 100),
  46. },
  47. )
  48. results = tuner.fit()
  49. print("Best hyperparameters: ", results.get_best_result().config)