| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- #!/usr/bin/env python
- # coding: utf-8
- #
- # This example showcases how to use TF2.0 APIs with Tune.
- # Original code: https://www.tensorflow.org/tutorials/quickstart/advanced
- #
- # As of 10/12/2019: One caveat of using TF2.0 is that TF AutoGraph
- # functionality does not interact nicely with Ray actors. One way to get around
- # this is to `import tensorflow` inside the Tune Trainable.
- #
- import argparse
- import os
- import sys
- from filelock import FileLock
- from ray import tune
# Limits how many training batches each Tune iteration consumes, so the
# example finishes quickly.
MAX_TRAIN_BATCH = 10

if sys.version_info < (3, 12):
    # Keras pieces used by the model and the trainable below.
    from tensorflow.keras import Model
    from tensorflow.keras.datasets.mnist import load_data
    from tensorflow.keras.layers import Conv2D, Dense, Flatten
else:
    # Tensorflow is not installed for Python 3.12 because of keras compatibility.
    # Exit successfully so the example is skipped rather than reported as failed.
    sys.exit(0)
class MyModel(Model):
    """Small convolutional classifier with a 10-way softmax output.

    Layers: one 3x3 convolution (32 filters, ReLU), a flatten, a hidden
    ``Dense`` layer (ReLU), and a ``Dense`` softmax output over 10 classes.
    The trainable in this file uses it for MNIST digits.

    Args:
        hiddens: Number of units in the hidden ``Dense`` layer.
    """

    def __init__(self, hiddens=128):
        super().__init__()
        self.conv1 = Conv2D(32, 3, activation="relu")
        self.flatten = Flatten()
        self.d1 = Dense(hiddens, activation="relu")
        # Softmax output: the paired loss must expect probabilities, not logits.
        self.d2 = Dense(10, activation="softmax")

    def call(self, x):
        """Return per-class probabilities for the image batch ``x``.

        ``x`` is presumably a channels-last 4-D batch (Conv2D's default data
        format); the trainable in this file appends a trailing channel axis.
        """
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)
class MNISTTrainable(tune.Trainable):
    """Ray Tune class-API trainable that fits ``MyModel`` on MNIST.

    Each ``step()`` call trains on at most ``MAX_TRAIN_BATCH`` shuffled
    batches, evaluates on the whole test split, and returns that iteration's
    metrics.

    Recognized ``config`` keys (both optional):
        batch: Training batch size (default 32).
        hiddens: Width of ``MyModel``'s hidden layer (default 128).
    """

    def setup(self, config):
        """Load MNIST and build the model, optimizer, metrics and step functions.

        Args:
            config: Trial hyperparameters; reads ``"batch"`` and ``"hiddens"``.
        """
        # IMPORTANT: TensorFlow is imported here, inside the trainable, because
        # TF AutoGraph does not interact nicely with Ray actors otherwise
        # (see the note at the top of this file).
        import tensorflow as tf

        # Serialize dataset loading across concurrently starting trials to
        # avoid race conditions (load_data may download and cache the files).
        with FileLock(os.path.expanduser("~/.tune.lock")):
            (x_train, y_train), (x_test, y_test) = load_data()

        # Scale pixels to [0, 1] and store them as float32. Dividing by 255.0
        # alone yields float64 arrays, doubling dataset/shuffle-buffer memory
        # for values the float32 model cannot use at that precision anyway.
        x_train = (x_train / 255.0).astype("float32")
        x_test = (x_test / 255.0).astype("float32")

        # Add a trailing channels dimension for Conv2D.
        x_train = x_train[..., tf.newaxis]
        x_test = x_test[..., tf.newaxis]

        self.train_ds = (
            tf.data.Dataset.from_tensor_slices((x_train, y_train))
            .shuffle(10000)
            .batch(config.get("batch", 32))
        )
        self.test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

        self.model = MyModel(hiddens=config.get("hiddens", 128))
        # MyModel ends in a softmax, so the loss consumes probabilities
        # (SparseCategoricalCrossentropy's from_logits defaults to False).
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
        self.optimizer = tf.keras.optimizers.Adam()

        self.train_loss = tf.keras.metrics.Mean(name="train_loss")
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name="train_accuracy"
        )
        self.test_loss = tf.keras.metrics.Mean(name="test_loss")
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name="test_accuracy"
        )

        @tf.function
        def train_step(images, labels):
            # Forward pass recorded on the tape, then one optimizer update.
            with tf.GradientTape() as tape:
                predictions = self.model(images)
                loss = self.loss_object(labels, predictions)
            gradients = tape.gradient(loss, self.model.trainable_variables)
            self.optimizer.apply_gradients(
                zip(gradients, self.model.trainable_variables)
            )
            self.train_loss.update_state(loss)
            self.train_accuracy.update_state(labels, predictions)

        @tf.function
        def test_step(images, labels):
            # Evaluation only: no gradient tape, no weight update.
            predictions = self.model(images)
            t_loss = self.loss_object(labels, predictions)
            self.test_loss.update_state(t_loss)
            self.test_accuracy.update_state(labels, predictions)

        self.tf_train_step = train_step
        self.tf_test_step = test_step

    def save_checkpoint(self, checkpoint_dir: str):
        """Checkpointing is not implemented in this example.

        Returning ``None`` writes nothing, so model and optimizer state are
        not persisted.
        """
        return None

    def load_checkpoint(self, checkpoint):
        """No-op counterpart of ``save_checkpoint``; nothing is restored."""
        return None

    def step(self):
        """Run one capped training epoch plus a full test pass.

        Returns:
            A dict of plain numpy scalars: iteration number, train/test loss,
            and train/test accuracy in percent (``mean_accuracy`` is the test
            accuracy).
        """
        # Metrics accumulate across calls, so clear them at each iteration.
        # ``reset_state`` replaces the deprecated ``reset_states``, which
        # Keras 3 no longer provides.
        for metric in (
            self.train_loss,
            self.train_accuracy,
            self.test_loss,
            self.test_accuracy,
        ):
            metric.reset_state()

        # Train on exactly MAX_TRAIN_BATCH batches to keep the example fast
        # (the previous ``idx > MAX_TRAIN_BATCH`` check ran one batch too
        # many). This cap is optional: iterate ``self.train_ds`` directly to
        # use the full epoch.
        for images, labels in self.train_ds.take(MAX_TRAIN_BATCH):
            self.tf_train_step(images, labels)

        for test_images, test_labels in self.test_ds:
            self.tf_test_step(test_images, test_labels)

        # It is important to return tf.Tensors as numpy objects.
        return {
            "epoch": self.iteration,
            "loss": self.train_loss.result().numpy(),
            "accuracy": self.train_accuracy.result().numpy() * 100,
            "test_loss": self.test_loss.result().numpy(),
            "mean_accuracy": self.test_accuracy.result().numpy() * 100,
        }
if __name__ == "__main__":
    # parse_known_args ignores unrecognized flags (presumably so test
    # harnesses can pass extra arguments without breaking the script).
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    cli_args, _unknown = arg_parser.parse_known_args()

    # Each trial stops after 50 iterations, or 5 in smoke-test mode.
    max_iterations = 50
    if cli_args.smoke_test:
        max_iterations = 5

    # Grid-search the hidden-layer width; trials are compared on the
    # reported "test_loss" metric, lower being better.
    search_space = {"hiddens": tune.grid_search([32, 64, 128])}
    stop_criteria = {"training_iteration": max_iterations}

    tuner = tune.Tuner(
        MNISTTrainable,
        tune_config=tune.TuneConfig(metric="test_loss", mode="min"),
        run_config=tune.RunConfig(stop=stop_criteria, verbose=1),
        param_space=search_space,
    )
    best_result = tuner.fit().get_best_result()
    print("Best hyperparameters found were: ", best_result.config)
|