tf_mnist_example.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. #
  4. # This example showcases how to use TF2.0 APIs with Tune.
  5. # Original code: https://www.tensorflow.org/tutorials/quickstart/advanced
  6. #
  7. # As of 10/12/2019: One caveat of using TF2.0 is that TF AutoGraph
  8. # functionality does not interact nicely with Ray actors. One way to get around
  9. # this is to `import tensorflow` inside the Tune Trainable.
  10. #
  11. import argparse
  12. import os
  13. import sys
  14. from filelock import FileLock
  15. from ray import tune
  # Highest batch index trained on during one ``step()``: the training loop
  # breaks as soon as the batch index exceeds this value.
  MAX_TRAIN_BATCH = 10

  if sys.version_info < (3, 12):
      from tensorflow.keras import Model
      from tensorflow.keras.datasets.mnist import load_data
      from tensorflow.keras.layers import Conv2D, Dense, Flatten
  else:
      # Tensorflow is not installed for Python 3.12 and newer because of keras
      # compatibility, so exit successfully instead of failing on the imports.
      sys.exit(0)
  class MyModel(Model):
      """Small convolutional classifier with a 10-way softmax output.

      The network is: one 3x3 convolution with 32 filters (ReLU), a flatten
      layer, a ReLU dense layer of configurable width, and a softmax dense
      layer with 10 units.
      """

      def __init__(self, hiddens=128):
          """Create the layers.

          Args:
              hiddens: Number of units in the hidden dense layer.
          """
          super().__init__()
          self.conv1 = Conv2D(32, 3, activation="relu")
          self.flatten = Flatten()
          self.d1 = Dense(hiddens, activation="relu")
          self.d2 = Dense(10, activation="softmax")

      def call(self, x):
          """Run the forward pass and return the softmax layer's output."""
          conv_out = self.conv1(x)
          flat = self.flatten(conv_out)
          hidden = self.d1(flat)
          return self.d2(hidden)
  class MNISTTrainable(tune.Trainable):
      """Tune Trainable that trains ``MyModel`` on MNIST.

      Each call to ``step()`` trains on a capped number of batches, evaluates
      on the full test set and reports loss/accuracy metrics.
      """

      def setup(self, config):
          """Load MNIST and build the datasets, model, optimizer and metrics.

          Args:
              config: Trial configuration. Reads ``"batch"`` (training batch
                  size, default 32) and ``"hiddens"`` (units in the hidden
                  dense layer of ``MyModel``, default 128).
          """
          # IMPORTANT: TensorFlow is imported inside the Trainable because TF
          # AutoGraph functionality does not interact nicely with Ray actors.
          import tensorflow as tf

          # Use FileLock to avoid race conditions while loading the dataset.
          with FileLock(os.path.expanduser("~/.tune.lock")):
              (x_train, y_train), (x_test, y_test) = load_data()
          x_train, x_test = x_train / 255.0, x_test / 255.0

          # Add a channels dimension
          x_train = x_train[..., tf.newaxis]
          x_test = x_test[..., tf.newaxis]

          self.train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
          self.train_ds = self.train_ds.shuffle(10000).batch(config.get("batch", 32))

          self.test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

          self.model = MyModel(hiddens=config.get("hiddens", 128))
          self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
          self.optimizer = tf.keras.optimizers.Adam()

          # Running metrics; they are cleared at the start of every step().
          self.train_loss = tf.keras.metrics.Mean(name="train_loss")
          self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
              name="train_accuracy"
          )

          self.test_loss = tf.keras.metrics.Mean(name="test_loss")
          self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
              name="test_accuracy"
          )

          @tf.function
          def train_step(images, labels):
              # Forward pass under the tape, then one optimizer update.
              with tf.GradientTape() as tape:
                  predictions = self.model(images)
                  loss = self.loss_object(labels, predictions)
              gradients = tape.gradient(loss, self.model.trainable_variables)
              self.optimizer.apply_gradients(
                  zip(gradients, self.model.trainable_variables)
              )

              self.train_loss(loss)
              self.train_accuracy(labels, predictions)

          @tf.function
          def test_step(images, labels):
              # Evaluation only: no gradient computation, no weight update.
              predictions = self.model(images)
              t_loss = self.loss_object(labels, predictions)

              self.test_loss(t_loss)
              self.test_accuracy(labels, predictions)

          self.tf_train_step = train_step
          self.tf_test_step = test_step

      def save_checkpoint(self, checkpoint_dir: str):
          """Checkpointing is not implemented in this example; returns None."""
          return None

      def load_checkpoint(self, checkpoint):
          """Checkpoint restoring is not implemented in this example; returns None."""
          return None

      @staticmethod
      def _reset_metric(metric):
          """Clear a Keras metric's accumulated state across Keras versions.

          ``reset_states()`` was deprecated in favour of ``reset_state()`` and
          is no longer available on newer Keras metrics, so prefer the new
          name and only fall back to the old one when it is missing.
          """
          reset = getattr(metric, "reset_state", None)
          if reset is None:
              reset = metric.reset_states
          reset()

      def step(self):
          """Run one training iteration and return its metrics.

          Trains on batches from ``self.train_ds`` until the batch index
          exceeds ``MAX_TRAIN_BATCH`` (so at most ``MAX_TRAIN_BATCH + 1``
          batches), then evaluates on every batch of ``self.test_ds``.

          Returns:
              A dict with ``"epoch"`` (the current iteration), ``"loss"`` and
              ``"test_loss"`` (mean losses), and ``"accuracy"`` and
              ``"mean_accuracy"`` (train and test accuracy multiplied by 100).
          """
          for metric in (
              self.train_loss,
              self.train_accuracy,
              self.test_loss,
              self.test_accuracy,
          ):
              self._reset_metric(metric)

          for idx, (images, labels) in enumerate(self.train_ds):
              if idx > MAX_TRAIN_BATCH:  # This is optional and can be removed.
                  break
              self.tf_train_step(images, labels)

          for test_images, test_labels in self.test_ds:
              self.tf_test_step(test_images, test_labels)

          # It is important to return tf.Tensors as numpy objects.
          return {
              "epoch": self.iteration,
              "loss": self.train_loss.result().numpy(),
              "accuracy": self.train_accuracy.result().numpy() * 100,
              "test_loss": self.test_loss.result().numpy(),
              "mean_accuracy": self.test_accuracy.result().numpy() * 100,
          }
  if __name__ == "__main__":
      # Command-line handling: unknown arguments are tolerated and ignored.
      parser = argparse.ArgumentParser()
      parser.add_argument(
          "--smoke-test", action="store_true", help="Finish quickly for testing"
      )
      args, _ = parser.parse_known_args()

      # A smoke test stops each trial after 5 iterations instead of 50.
      max_iterations = 5 if args.smoke_test else 50

      # Pick the trial with the lowest test loss.
      tune_config = tune.TuneConfig(
          metric="test_loss",
          mode="min",
      )
      run_config = tune.RunConfig(
          stop={"training_iteration": max_iterations},
          verbose=1,
      )
      # Grid-search the width of the hidden dense layer.
      search_space = {"hiddens": tune.grid_search([32, 64, 128])}

      tuner = tune.Tuner(
          MNISTTrainable,
          tune_config=tune_config,
          run_config=run_config,
          param_space=search_space,
      )
      results = tuner.fit()

      print("Best hyperparameters found were: ", results.get_best_result().config)