# tune_mnist_keras.py (2.8 KB)
  1. import argparse
  2. import os
  3. import sys
  4. from filelock import FileLock
  5. import ray
  6. from ray import tune
  7. from ray.tune.schedulers import AsyncHyperBandScheduler
  8. if sys.version_info >= (3, 12):
  9. # Tensorflow is not installed for Python 3.12 because of keras compatibility.
  10. sys.exit(0)
  11. else:
  12. from tensorflow.keras.datasets import mnist
  13. from ray.tune.integration.keras import TuneReportCheckpointCallback
  14. def train_mnist(config):
  15. # https://github.com/tensorflow/tensorflow/issues/32159
  16. import tensorflow as tf
  17. batch_size = 128
  18. num_classes = 10
  19. epochs = 12
  20. with FileLock(os.path.expanduser("~/.data.lock")):
  21. (x_train, y_train), (x_test, y_test) = mnist.load_data()
  22. x_train, x_test = x_train / 255.0, x_test / 255.0
  23. model = tf.keras.models.Sequential(
  24. [
  25. tf.keras.layers.Flatten(input_shape=(28, 28)),
  26. tf.keras.layers.Dense(config["hidden"], activation="relu"),
  27. tf.keras.layers.Dropout(0.2),
  28. tf.keras.layers.Dense(num_classes, activation="softmax"),
  29. ]
  30. )
  31. model.compile(
  32. loss="sparse_categorical_crossentropy",
  33. optimizer=tf.keras.optimizers.SGD(lr=config["lr"], momentum=config["momentum"]),
  34. metrics=["accuracy"],
  35. )
  36. model.fit(
  37. x_train,
  38. y_train,
  39. batch_size=batch_size,
  40. epochs=epochs,
  41. verbose=0,
  42. validation_data=(x_test, y_test),
  43. callbacks=[
  44. TuneReportCheckpointCallback(
  45. checkpoint_on=[], metrics={"mean_accuracy": "accuracy"}
  46. )
  47. ],
  48. )
  49. def tune_mnist(num_training_iterations):
  50. sched = AsyncHyperBandScheduler(
  51. time_attr="training_iteration", max_t=400, grace_period=20
  52. )
  53. tuner = tune.Tuner(
  54. tune.with_resources(train_mnist, resources={"cpu": 2, "gpu": 0}),
  55. run_config=tune.RunConfig(
  56. name="exp",
  57. stop={"mean_accuracy": 0.99, "training_iteration": num_training_iterations},
  58. ),
  59. tune_config=tune.TuneConfig(
  60. scheduler=sched,
  61. metric="mean_accuracy",
  62. mode="max",
  63. num_samples=10,
  64. ),
  65. param_space={
  66. "threads": 2,
  67. "lr": tune.uniform(0.001, 0.1),
  68. "momentum": tune.uniform(0.1, 0.9),
  69. "hidden": tune.randint(32, 512),
  70. },
  71. )
  72. results = tuner.fit()
  73. print("Best hyperparameters found were: ", results.get_best_result().config)
  74. if __name__ == "__main__":
  75. parser = argparse.ArgumentParser()
  76. parser.add_argument(
  77. "--smoke-test", action="store_true", help="Finish quickly for testing"
  78. )
  79. args, _ = parser.parse_known_args()
  80. if args.smoke_test:
  81. ray.init(num_cpus=4)
  82. tune_mnist(num_training_iterations=2 if args.smoke_test else 300)