"""Tune a small MNIST classifier (PyTorch Lightning) with Ray Tune.

Run as a script; pass ``--smoke-test`` to run a single one-epoch trial.
"""

import math
import os

import pytorch_lightning as pl
import torch
from datasets import load_dataset
from filelock import FileLock
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision import transforms

from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

# Directory used as the dataset cache; overridable through the environment.
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")


class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module serving the ``ylecun/mnist`` dataset.

    Args:
        batch_size: Batch size used by both the train and validation loaders.
        data_dir: Directory passed to ``load_dataset`` as ``cache_dir``.
    """

    def __init__(self, batch_size: int, data_dir: str = PATH_DATASETS):
        super().__init__()
        self.data_dir = data_dir
        # Convert each image to a tensor, then normalize with fixed constants.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )
        self.batch_size = batch_size
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Fetch the dataset into ``data_dir`` while holding a file lock."""
        # The lock serializes concurrent downloads from processes that share
        # the same home directory.
        with FileLock(os.path.expanduser("~/.data.lock")):
            load_dataset("ylecun/mnist", cache_dir=self.data_dir)

    def setup(self, stage=None):
        """Load the dataset and eagerly transform both splits into lists.

        Args:
            stage: Accepted for the Lightning hook signature; not used here.
        """
        dataset = load_dataset("ylecun/mnist", cache_dir=self.data_dir)

        def transform_fn(sample):
            # Each sample becomes an (image_tensor, label) pair.
            return (self.transform(sample["image"]), sample["label"])

        self.mnist_train = [transform_fn(sample) for sample in dataset["train"]]
        # The dataset's "test" split is used as the validation set.
        self.mnist_val = [transform_fn(sample) for sample in dataset["test"]]

    def train_dataloader(self) -> DataLoader:
        """Return the loader over the transformed training split."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        """Return the loader over the transformed validation split."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)


class LightningMNISTClassifier(pl.LightningModule):
    """Three-layer fully connected classifier for 28x28 inputs.

    Args:
        config: Mapping with the keys ``"lr"``, ``"layer_1"``, ``"layer_2"``
            and ``"batch_size"``.
        data_dir: Stored on the instance; defaults to the current working
            directory when falsy.
    """

    def __init__(self, config, data_dir=None):
        super().__init__()

        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]
        layer_1, layer_2 = config["layer_1"], config["layer_2"]
        self.batch_size = config["batch_size"]

        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
        self.layer_2 = torch.nn.Linear(layer_1, layer_2)
        self.layer_3 = torch.nn.Linear(layer_2, 10)
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened and must total 28 * 28 elements.

        Returns:
            Tensor of shape ``(batch, 10)`` holding log-softmax outputs.
        """
        # Only the batch dimension is needed to flatten, so read it directly
        # instead of unpacking (and discarding) every dimension.
        x = x.view(x.size(0), -1)
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        x = self.layer_3(x)
        return torch.log_softmax(x, dim=1)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        """Compute, log and return the NLL loss for one training batch."""
        x, y = train_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Return the loss and accuracy of one validation batch."""
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    # NOTE(review): `validation_epoch_end` looks like the Lightning 1.x hook
    # name -- confirm the pinned pytorch_lightning version before upgrading.
    def validation_epoch_end(self, outputs):
        """Log the mean of the per-batch validation loss and accuracy.

        Args:
            outputs: The dicts returned by ``validation_step``.
        """
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)


def train_mnist_tune(config, num_epochs=10, num_gpus=0):
    """Train one classifier for a single Tune trial.

    Args:
        config: Hyperparameters for the trial; see
            ``LightningMNISTClassifier`` for the keys that are read.
        num_epochs: Maximum number of epochs handed to the trainer.
        num_gpus: GPUs for the trial; rounded up to an integer.
    """
    data_dir = os.path.abspath("./data")
    model = LightningMNISTClassifier(config, data_dir)
    with FileLock(os.path.expanduser("~/.data.lock")):
        dm = MNISTDataModule(data_dir=data_dir, batch_size=config["batch_size"])

    # Names reported to Tune mapped to the names the model logs.
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        # If fractional GPUs passed in, convert to int.
        # NOTE(review): `gpus=` looks like a Lightning 1.x-era argument --
        # confirm against the pinned pytorch_lightning version.
        gpus=math.ceil(num_gpus),
        enable_progress_bar=False,
        callbacks=[
            TuneReportCheckpointCallback(
                metrics, on="validation_end", save_checkpoints=False
            )
        ],
    )
    trainer.fit(model, dm)


def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
    """Run the hyperparameter search and print the best configuration.

    Args:
        num_samples: Number of configurations sampled from the search space.
        num_epochs: Maximum number of epochs per trial.
        gpus_per_trial: GPU resources requested for each trial.
    """
    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

    trainable = tune.with_parameters(
        train_mnist_tune, num_epochs=num_epochs, num_gpus=gpus_per_trial
    )
    tuner = tune.Tuner(
        tune.with_resources(trainable, resources={"cpu": 1, "gpu": gpus_per_trial}),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            num_samples=num_samples,
        ),
        run_config=tune.RunConfig(
            name="tune_mnist",
        ),
        param_space=config,
    )
    results = tuner.fit()

    print("Best hyperparameters found were: ", results.get_best_result().config)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        tune_mnist(num_samples=1, num_epochs=1, gpus_per_trial=0)
    else:
        tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0)