"""An example showing how to use Pytorch Lightning training, Ray Tune HPO, and MLflow autologging all together.""" import os import tempfile import mlflow import pytorch_lightning as pl from ray import tune from ray.air.integrations.mlflow import setup_mlflow from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier, MNISTDataModule from ray.tune.integration.pytorch_lightning import TuneReportCallback def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0): setup_mlflow( config, experiment_name=config.get("experiment_name", None), tracking_uri=config.get("tracking_uri", None), ) model = LightningMNISTClassifier(config, data_dir) dm = MNISTDataModule( data_dir=data_dir, num_workers=1, batch_size=config["batch_size"] ) metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"} mlflow.pytorch.autolog() trainer = pl.Trainer( max_epochs=num_epochs, gpus=num_gpus, progress_bar_refresh_rate=0, callbacks=[TuneReportCallback(metrics, on="validation_end")], ) trainer.fit(model, dm) def tune_mnist( num_samples=10, num_epochs=10, gpus_per_trial=0, tracking_uri=None, experiment_name="ptl_autologging_example", ): data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_") # Download data MNISTDataModule(data_dir=data_dir, batch_size=32).prepare_data() # Set the MLflow experiment, or create it if it does not exist. mlflow.set_tracking_uri(tracking_uri) mlflow.set_experiment(experiment_name) config = { "layer_1": tune.choice([32, 64, 128]), "layer_2": tune.choice([64, 128, 256]), "lr": tune.loguniform(1e-4, 1e-1), "batch_size": tune.choice([32, 64, 128]), "experiment_name": experiment_name, "tracking_uri": mlflow.get_tracking_uri(), "data_dir": os.path.join(tempfile.gettempdir(), "mnist_data_"), "num_epochs": num_epochs, } trainable = tune.with_parameters( train_mnist_tune, data_dir=data_dir, num_epochs=num_epochs, num_gpus=gpus_per_trial, ) tuner = tune.Tuner( tune.with_resources(trainable, resources={"cpu": 1, "gpu": gpus_per_trial}), tune_config=tune.TuneConfig( metric="loss", mode="min", num_samples=num_samples, ), run_config=tune.RunConfig( name="tune_mnist", ), param_space=config, ) results = tuner.fit() print("Best hyperparameters found were: ", results.get_best_result().config) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "--smoke-test", action="store_true", help="Finish quickly for testing" ) args, _ = parser.parse_known_args() if args.smoke_test: tune_mnist( num_samples=1, num_epochs=1, gpus_per_trial=0, tracking_uri=os.path.join(tempfile.gettempdir(), "mlruns"), ) else: tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0)