mnist_ptl_mini.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import math
  2. import os
  3. import pytorch_lightning as pl
  4. import torch
  5. from datasets import load_dataset
  6. from filelock import FileLock
  7. from torch.nn import functional as F
  8. from torch.utils.data import DataLoader
  9. from torchmetrics import Accuracy
  10. from torchvision import transforms
  11. from ray import tune
  12. from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
  13. PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
  14. class MNISTDataModule(pl.LightningDataModule):
  15. def __init__(self, batch_size: int, data_dir: str = PATH_DATASETS):
  16. super().__init__()
  17. self.data_dir = data_dir
  18. self.transform = transforms.Compose(
  19. [
  20. transforms.ToTensor(),
  21. transforms.Normalize((0.1307,), (0.3081,)),
  22. ]
  23. )
  24. self.batch_size = batch_size
  25. self.dims = (1, 28, 28)
  26. self.num_classes = 10
  27. def prepare_data(self):
  28. # download
  29. with FileLock(os.path.expanduser("~/.data.lock")):
  30. load_dataset("ylecun/mnist", cache_dir=self.data_dir)
  31. def setup(self, stage=None):
  32. dataset = load_dataset("ylecun/mnist", cache_dir=self.data_dir)
  33. def transform_fn(sample):
  34. return (self.transform(sample["image"]), sample["label"])
  35. self.mnist_train = [transform_fn(sample) for sample in dataset["train"]]
  36. self.mnist_val = [transform_fn(sample) for sample in dataset["test"]]
  37. def train_dataloader(self):
  38. return DataLoader(self.mnist_train, batch_size=self.batch_size)
  39. def val_dataloader(self):
  40. return DataLoader(self.mnist_val, batch_size=self.batch_size)
  41. class LightningMNISTClassifier(pl.LightningModule):
  42. def __init__(self, config, data_dir=None):
  43. super(LightningMNISTClassifier, self).__init__()
  44. self.data_dir = data_dir or os.getcwd()
  45. self.lr = config["lr"]
  46. layer_1, layer_2 = config["layer_1"], config["layer_2"]
  47. self.batch_size = config["batch_size"]
  48. # mnist images are (1, 28, 28) (channels, width, height)
  49. self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
  50. self.layer_2 = torch.nn.Linear(layer_1, layer_2)
  51. self.layer_3 = torch.nn.Linear(layer_2, 10)
  52. self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
  53. def forward(self, x):
  54. batch_size, channels, width, height = x.size()
  55. x = x.view(batch_size, -1)
  56. x = self.layer_1(x)
  57. x = torch.relu(x)
  58. x = self.layer_2(x)
  59. x = torch.relu(x)
  60. x = self.layer_3(x)
  61. x = torch.log_softmax(x, dim=1)
  62. return x
  63. def configure_optimizers(self):
  64. return torch.optim.Adam(self.parameters(), lr=self.lr)
  65. def training_step(self, train_batch, batch_idx):
  66. x, y = train_batch
  67. logits = self.forward(x)
  68. loss = F.nll_loss(logits, y)
  69. acc = self.accuracy(logits, y)
  70. self.log("ptl/train_loss", loss)
  71. self.log("ptl/train_accuracy", acc)
  72. return loss
  73. def validation_step(self, val_batch, batch_idx):
  74. x, y = val_batch
  75. logits = self.forward(x)
  76. loss = F.nll_loss(logits, y)
  77. acc = self.accuracy(logits, y)
  78. return {"val_loss": loss, "val_accuracy": acc}
  79. def validation_epoch_end(self, outputs):
  80. avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
  81. avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
  82. self.log("ptl/val_loss", avg_loss)
  83. self.log("ptl/val_accuracy", avg_acc)
  84. def train_mnist_tune(config, num_epochs=10, num_gpus=0):
  85. data_dir = os.path.abspath("./data")
  86. model = LightningMNISTClassifier(config, data_dir)
  87. with FileLock(os.path.expanduser("~/.data.lock")):
  88. dm = MNISTDataModule(data_dir=data_dir, batch_size=config["batch_size"])
  89. metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
  90. trainer = pl.Trainer(
  91. max_epochs=num_epochs,
  92. # If fractional GPUs passed in, convert to int.
  93. gpus=math.ceil(num_gpus),
  94. enable_progress_bar=False,
  95. callbacks=[
  96. TuneReportCheckpointCallback(
  97. metrics, on="validation_end", save_checkpoints=False
  98. )
  99. ],
  100. )
  101. trainer.fit(model, dm)
  102. def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
  103. config = {
  104. "layer_1": tune.choice([32, 64, 128]),
  105. "layer_2": tune.choice([64, 128, 256]),
  106. "lr": tune.loguniform(1e-4, 1e-1),
  107. "batch_size": tune.choice([32, 64, 128]),
  108. }
  109. trainable = tune.with_parameters(
  110. train_mnist_tune, num_epochs=num_epochs, num_gpus=gpus_per_trial
  111. )
  112. tuner = tune.Tuner(
  113. tune.with_resources(trainable, resources={"cpu": 1, "gpu": gpus_per_trial}),
  114. tune_config=tune.TuneConfig(
  115. metric="loss",
  116. mode="min",
  117. num_samples=num_samples,
  118. ),
  119. run_config=tune.RunConfig(
  120. name="tune_mnist",
  121. ),
  122. param_space=config,
  123. )
  124. results = tuner.fit()
  125. print("Best hyperparameters found were: ", results.get_best_result().config)
  126. if __name__ == "__main__":
  127. import argparse
  128. parser = argparse.ArgumentParser()
  129. parser.add_argument(
  130. "--smoke-test", action="store_true", help="Finish quickly for testing"
  131. )
  132. args, _ = parser.parse_known_args()
  133. if args.smoke_test:
  134. tune_mnist(num_samples=1, num_epochs=1, gpus_per_trial=0)
  135. else:
  136. tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0)