# mlflow_ptl.py 3.1 KB
"""An example showing how to use PyTorch Lightning training, Ray Tune
HPO, and MLflow autologging all together."""
import math
import os
import tempfile

import mlflow
import pytorch_lightning as pl
from ray import tune
from ray.air.integrations.mlflow import setup_mlflow
from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier, MNISTDataModule
from ray.tune.integration.pytorch_lightning import TuneReportCallback
  11. def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
  12. setup_mlflow(
  13. config,
  14. experiment_name=config.get("experiment_name", None),
  15. tracking_uri=config.get("tracking_uri", None),
  16. )
  17. model = LightningMNISTClassifier(config, data_dir)
  18. dm = MNISTDataModule(
  19. data_dir=data_dir, num_workers=1, batch_size=config["batch_size"]
  20. )
  21. metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
  22. mlflow.pytorch.autolog()
  23. trainer = pl.Trainer(
  24. max_epochs=num_epochs,
  25. gpus=num_gpus,
  26. progress_bar_refresh_rate=0,
  27. callbacks=[TuneReportCallback(metrics, on="validation_end")],
  28. )
  29. trainer.fit(model, dm)
  30. def tune_mnist(
  31. num_samples=10,
  32. num_epochs=10,
  33. gpus_per_trial=0,
  34. tracking_uri=None,
  35. experiment_name="ptl_autologging_example",
  36. ):
  37. data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
  38. # Download data
  39. MNISTDataModule(data_dir=data_dir, batch_size=32).prepare_data()
  40. # Set the MLflow experiment, or create it if it does not exist.
  41. mlflow.set_tracking_uri(tracking_uri)
  42. mlflow.set_experiment(experiment_name)
  43. config = {
  44. "layer_1": tune.choice([32, 64, 128]),
  45. "layer_2": tune.choice([64, 128, 256]),
  46. "lr": tune.loguniform(1e-4, 1e-1),
  47. "batch_size": tune.choice([32, 64, 128]),
  48. "experiment_name": experiment_name,
  49. "tracking_uri": mlflow.get_tracking_uri(),
  50. "data_dir": os.path.join(tempfile.gettempdir(), "mnist_data_"),
  51. "num_epochs": num_epochs,
  52. }
  53. trainable = tune.with_parameters(
  54. train_mnist_tune,
  55. data_dir=data_dir,
  56. num_epochs=num_epochs,
  57. num_gpus=gpus_per_trial,
  58. )
  59. tuner = tune.Tuner(
  60. tune.with_resources(trainable, resources={"cpu": 1, "gpu": gpus_per_trial}),
  61. tune_config=tune.TuneConfig(
  62. metric="loss",
  63. mode="min",
  64. num_samples=num_samples,
  65. ),
  66. run_config=tune.RunConfig(
  67. name="tune_mnist",
  68. ),
  69. param_space=config,
  70. )
  71. results = tuner.fit()
  72. print("Best hyperparameters found were: ", results.get_best_result().config)
  73. if __name__ == "__main__":
  74. import argparse
  75. parser = argparse.ArgumentParser()
  76. parser.add_argument(
  77. "--smoke-test", action="store_true", help="Finish quickly for testing"
  78. )
  79. args, _ = parser.parse_known_args()
  80. if args.smoke_test:
  81. tune_mnist(
  82. num_samples=1,
  83. num_epochs=1,
  84. gpus_per_trial=0,
  85. tracking_uri=os.path.join(tempfile.gettempdir(), "mlruns"),
  86. )
  87. else:
  88. tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0)